(args)
| 7 | |
| 8 | |
def main(args):
    """Decode a saved VQ code file back into an image and write it to disk.

    Builds the VQ tokenizer selected by ``args.vq_model`` and loads its
    weights from ``args.vq_ckpt``. It then decodes the token indices stored
    in the NumPy file ``args.code_path`` and saves the reconstruction as an
    image grid to ``sample_image_code.png`` in the current working directory.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``vq_model``,
            ``codebook_size``, ``codebook_embed_dim``, ``vq_ckpt``,
            ``image_size``, ``downsample_size`` and ``code_path``.
    """
    # Setup PyTorch: decoding needs no gradients, so disable autograd globally.
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # drop the checkpoint dict once its weights are copied in

    # load image code
    latent_dim = args.codebook_embed_dim
    latent_size = args.image_size // args.downsample_size
    codes = torch.from_numpy(np.load(args.code_path)).to(device)
    if codes.ndim == 3:  # flip augmentation
        # Dim 1 is decoded as the image batch. Presumably the layout is
        # (1, num_augmented_views, tokens) -- TODO confirm against the
        # script that writes the code files.
        num_images = codes.shape[1]
    else:
        num_images = 1
    qzshape = (num_images, latent_dim, latent_size, latent_size)
    # Flatten all indices; decode_code regroups them according to qzshape.
    index_sample = codes.reshape(-1)
    samples = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]

    # save
    out_path = "sample_image_code.png"
    # Size the grid from the number of decoded images, not codes.shape[1].
    # A 1-D code file has no dim 1 (IndexError). In a 2-D file, dim 1 is not
    # an image count, because only one image is decoded in that case.
    nrow = max(4, num_images // 2)
    save_image(samples, out_path, nrow=nrow, normalize=True, value_range=(-1, 1))
    print("Reconstructed image is saved to {}".format(out_path))
| 41 | |
| 42 | |
| 43 |
no test coverage detected