(args)
| 18 | |
| 19 | |
| 20 | def main(args): |
| 21 | # Setup PyTorch: |
| 22 | torch.manual_seed(args.seed) |
| 23 | torch.backends.cudnn.deterministic = True |
| 24 | torch.backends.cudnn.benchmark = False |
| 25 | torch.set_grad_enabled(False) |
| 26 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 27 | |
| 28 | # create and load model |
| 29 | vq_model = VQ_models[args.vq_model]( |
| 30 | codebook_size=args.codebook_size, |
| 31 | codebook_embed_dim=args.codebook_embed_dim) |
| 32 | vq_model.to(device) |
| 33 | vq_model.eval() |
| 34 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") |
| 35 | vq_model.load_state_dict(checkpoint["model"]) |
| 36 | del checkpoint |
| 37 | print(f"image tokenizer is loaded") |
| 38 | |
| 39 | # create and load gpt model |
| 40 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] |
| 41 | latent_size = args.image_size // args.downsample_size |
| 42 | gpt_model = GPT_models[args.gpt_model]( |
| 43 | block_size=latent_size ** 2, |
| 44 | cls_token_num=args.cls_token_num, |
| 45 | model_type=args.gpt_type, |
| 46 | ).to(device=device, dtype=precision) |
| 47 | |
| 48 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") |
| 49 | |
| 50 | if "model" in checkpoint: # ddp |
| 51 | model_weight = checkpoint["model"] |
| 52 | elif "module" in checkpoint: # deepspeed |
| 53 | model_weight = checkpoint["module"] |
| 54 | elif "state_dict" in checkpoint: |
| 55 | model_weight = checkpoint["state_dict"] |
| 56 | else: |
| 57 | raise Exception("please check model weight") |
| 58 | gpt_model.load_state_dict(model_weight, strict=False) |
| 59 | gpt_model.eval() |
| 60 | del checkpoint |
| 61 | print(f"gpt model is loaded") |
| 62 | |
| 63 | if args.compile: |
| 64 | print(f"compiling the model...") |
| 65 | gpt_model = torch.compile( |
| 66 | gpt_model, |
| 67 | mode="reduce-overhead", |
| 68 | fullgraph=True |
| 69 | ) # requires PyTorch 2.0 (optional) |
| 70 | else: |
| 71 | print(f"no need to compile model in demo") |
| 72 | |
| 73 | assert os.path.exists(args.t5_path) |
| 74 | t5_model = T5Embedder( |
| 75 | device=device, |
| 76 | local_cache=True, |
| 77 | cache_dir=args.t5_path, |
no test coverage detected