MCPcopy
hub / github.com/FoundationVision/LlamaGen / load_model

Function load_model

app.py:25–50  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

23}
24
def load_model(args):
    """Download the VQ tokenizer and GPT checkpoints, then build both models.

    The checkpoint filenames and the image size are looked up in
    ``model2ckpt`` by ``args.gpt_model``. Both files are fetched from the
    ``FoundationVision/LlamaGen`` Hugging Face repo into ``ckpt_folder``
    (the current directory).

    Args:
        args: Parsed options namespace. Read: ``gpt_model``, ``vq_model``,
            ``codebook_size``, ``codebook_embed_dim``. NOTE: this function
            also WRITES ``args.image_size`` and ``args.gpt_ckpt`` before
            handing ``args`` to ``LLM``.

    Returns:
        Tuple ``(vq_model, llm, image_size)``:
        - the VQ image tokenizer, moved to ``device``, in eval mode, with
          weights loaded;
        - the ``LLM`` instance;
        - the image size associated with ``args.gpt_model``.
    """
    import os  # local import keeps this block self-contained

    ckpt_folder = "./"
    repo_id = "FoundationVision/LlamaGen"
    vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
    hf_hub_download(repo_id=repo_id, filename=vq_ckpt, local_dir=ckpt_folder)
    hf_hub_download(repo_id=repo_id, filename=gpt_ckpt, local_dir=ckpt_folder)

    # Build paths with os.path.join instead of raw concatenation, so they
    # stay valid if ckpt_folder ever lacks a trailing separator.
    # For "./" the result is identical to before ("./<file>").
    vq_path = os.path.join(ckpt_folder, vq_ckpt)
    gpt_path = os.path.join(ckpt_folder, gpt_ckpt)

    # Create the image tokenizer and load its weights.
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    # The checkpoint is mapped to CPU; load_state_dict then copies the
    # tensors into the module's parameters, which already live on `device`.
    # NOTE(review): torch.load unpickles the file, which was downloaded
    # above. Consider weights_only=True once it is confirmed the
    # checkpoint holds no non-tensor objects.
    checkpoint = torch.load(vq_path, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # drop the CPU copy of the weights
    print("image tokenizer is loaded")

    # Create the LLM. image_size and gpt_ckpt are stored on args before the
    # LLM(args=args, ...) call -- presumably consumed there; verify
    # against LLM.
    args.image_size = image_size
    args.gpt_ckpt = gpt_path
    llm = LLM(
        args=args,
        model='serve/fake_json/{}.json'.format(args.gpt_model),
        gpu_memory_utilization=0.6,
        skip_tokenizer_init=True)
    print("gpt model is loaded")
    return vq_model, llm, image_size
51
52
53def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):

Callers 1

app.py · File · 0.70

Calls 3

LLM · Class · 0.90
print · Function · 0.85
load · Method · 0.80

Tested by

no test coverage detected