MCPcopy
hub / github.com/FoundationVision/LlamaGen / main

Function main

autoregressive/sample/sample_t2i.py:20–126  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

18
19
20def main(args):
21 # Setup PyTorch:
22 torch.manual_seed(args.seed)
23 torch.backends.cudnn.deterministic = True
24 torch.backends.cudnn.benchmark = False
25 torch.set_grad_enabled(False)
26 device = "cuda" if torch.cuda.is_available() else "cpu"
27
28 # create and load model
29 vq_model = VQ_models[args.vq_model](
30 codebook_size=args.codebook_size,
31 codebook_embed_dim=args.codebook_embed_dim)
32 vq_model.to(device)
33 vq_model.eval()
34 checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
35 vq_model.load_state_dict(checkpoint["model"])
36 del checkpoint
37 print(f"image tokenizer is loaded")
38
39 # create and load gpt model
40 precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
41 latent_size = args.image_size // args.downsample_size
42 gpt_model = GPT_models[args.gpt_model](
43 block_size=latent_size ** 2,
44 cls_token_num=args.cls_token_num,
45 model_type=args.gpt_type,
46 ).to(device=device, dtype=precision)
47
48 checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
49
50 if "model" in checkpoint: # ddp
51 model_weight = checkpoint["model"]
52 elif "module" in checkpoint: # deepspeed
53 model_weight = checkpoint["module"]
54 elif "state_dict" in checkpoint:
55 model_weight = checkpoint["state_dict"]
56 else:
57 raise Exception("please check model weight")
58 gpt_model.load_state_dict(model_weight, strict=False)
59 gpt_model.eval()
60 del checkpoint
61 print(f"gpt model is loaded")
62
63 if args.compile:
64 print(f"compiling the model...")
65 gpt_model = torch.compile(
66 gpt_model,
67 mode="reduce-overhead",
68 fullgraph=True
69 ) # requires PyTorch 2.0 (optional)
70 else:
71 print(f"no need to compile model in demo")
72
73 assert os.path.exists(args.t5_path)
74 t5_model = T5Embedder(
75 device=device,
76 local_cache=True,
77 cache_dir=args.t5_path,

Callers 1

sample_t2i.py (File) · 0.70

Calls 6

get_text_embeddings (Method) · 0.95
T5Embedder (Class) · 0.90
generate (Function) · 0.90
print (Function) · 0.85
load (Method) · 0.80
decode_code (Method) · 0.45

Tested by

no test coverage detected