| 23 | } |
| 24 | |
def load_model(args):
    """Download checkpoints and build the image tokenizer and GPT model.

    Looks up ``(vq_ckpt, gpt_ckpt, image_size)`` for ``args.gpt_model`` in
    ``model2ckpt``. It then downloads both checkpoint files from the
    ``FoundationVision/LlamaGen`` Hugging Face repo into the current
    directory, loads the VQ model weights from the checkpoint's ``"model"``
    entry, and constructs the ``LLM``.

    Side effects: sets ``args.image_size`` and ``args.gpt_ckpt`` before
    constructing the ``LLM``, which receives ``args``.

    Args:
        args: Parsed options. Reads ``gpt_model``, ``vq_model``,
            ``codebook_size`` and ``codebook_embed_dim``.

    Returns:
        Tuple ``(vq_model, llm, image_size)``.
    """
    ckpt_folder = "./"
    vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
    # hf_hub_download returns the local path of the downloaded file. Use it
    # directly instead of re-deriving the path by string concatenation,
    # which silently breaks if ckpt_folder lacks a trailing separator.
    vq_path = hf_hub_download(
        repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder
    )
    gpt_path = hf_hub_download(
        repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder
    )

    # Create and load the image tokenizer (VQ model).
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim,
    )
    # NOTE(review): `device` is not defined in this block; presumably a
    # module-level global elsewhere in the file.
    vq_model.to(device)
    vq_model.eval()
    # NOTE(review): torch.load unpickles the file, and this one comes from a
    # remote hub. Consider weights_only=True once it is confirmed the
    # checkpoint holds only tensors/plain containers.
    checkpoint = torch.load(vq_path, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # release the loaded checkpoint dict once weights are copied
    print("image tokenizer is loaded")

    # Create an LLM. image_size and gpt_ckpt are stored on args before the
    # LLM is built (presumably consumed by it; not visible here).
    args.image_size = image_size
    args.gpt_ckpt = gpt_path
    llm = LLM(
        args=args,
        model=f"serve/fake_json/{args.gpt_model}.json",
        gpu_memory_utilization=0.6,
        skip_tokenizer_init=True,
    )
    print("gpt model is loaded")
    return vq_model, llm, image_size
| 51 | |
| 52 | |
| 53 | def infer(cfg_scale, top_k, top_p, temperature, class_label, seed): |