hub / github.com/FoundationVision/LlamaGen / main

Function main

tools/check_image_codes.py:9–40
(args)

Source from the content-addressed store, hash-verified

def main(args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint

    # load image code
    latent_dim = args.codebook_embed_dim
    latent_size = args.image_size // args.downsample_size
    codes = torch.from_numpy(np.load(args.code_path)).to(device)
    if codes.ndim == 3:  # flip augmentation
        qzshape = (codes.shape[1], latent_dim, latent_size, latent_size)
    else:
        qzshape = (1, latent_dim, latent_size, latent_size)
    index_sample = codes.reshape(-1)
    samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]

    # save
    out_path = "sample_image_code.png"
    nrow = max(4, int(codes.shape[1] // 2))
    save_image(samples, out_path, nrow=nrow, normalize=True, value_range=(-1, 1))
    print("Reconstructed image is saved to {}".format(out_path))
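
The shape handling in the "load image code" step can be checked without a GPU or a checkpoint. The sketch below mirrors the `qzshape` logic with plain tuples. The helper name `latent_grid_shape`, the assumed 3-D layout `(1, num_augmentations, tokens)`, the token-count sanity check, and the example values (256-pixel images, downsample factor 16, embedding dimension 8) are illustrative assumptions, not part of the repository.

```python
def latent_grid_shape(codes_shape, image_size, downsample_size, embed_dim):
    """Mirror the qzshape computation in main() using plain tuples.

    Hypothetical helper for illustration; not part of LlamaGen.
    """
    # Side length of the square latent grid, as in main().
    latent_size = image_size // downsample_size

    # main() treats a 3-D code array as augmented codes and uses
    # dimension 1 as the batch size; otherwise the batch size is 1.
    batch = codes_shape[1] if len(codes_shape) == 3 else 1

    # Extra sanity check (an addition, not in main()): after flattening,
    # the number of code indices should fill the latent grid exactly.
    total = 1
    for dim in codes_shape:
        total *= dim
    expected = batch * latent_size * latent_size
    if total != expected:
        raise ValueError(
            "expected {} code indices, got {}".format(expected, total))

    return (batch, embed_dim, latent_size, latent_size)


# Two augmented views of one image, 16 x 16 = 256 tokens each.
print(latent_grid_shape((1, 2, 256), 256, 16, 8))
# A single un-augmented code sequence.
print(latent_grid_shape((1, 256), 256, 16, 8))
```

A mismatch between `image_size`, `downsample_size`, and the saved code file would raise `ValueError` here instead of surfacing later inside `decode_code`.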

Callers 1

Calls 3

print · Function · 0.85
load · Method · 0.80
decode_code · Method · 0.45

Tested by

no test coverage detected