(args)
| 41 | |
| 42 | |
| 43 | def main(args): |
| 44 | # Setup PyTorch: |
| 45 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" |
| 46 | torch.set_grad_enabled(False) |
| 47 | |
| 48 | # Setup DDP: |
| 49 | dist.init_process_group("nccl") |
| 50 | rank = dist.get_rank() |
| 51 | device = rank % torch.cuda.device_count() |
| 52 | seed = args.global_seed * dist.get_world_size() + rank |
| 53 | torch.manual_seed(seed) |
| 54 | torch.cuda.set_device(device) |
| 55 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") |
| 56 | |
| 57 | # create and load model |
| 58 | vq_model = VQ_models[args.vq_model]( |
| 59 | codebook_size=args.codebook_size, |
| 60 | codebook_embed_dim=args.codebook_embed_dim) |
| 61 | vq_model.to(device) |
| 62 | vq_model.eval() |
| 63 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") |
| 64 | if "ema" in checkpoint: # ema |
| 65 | model_weight = checkpoint["ema"] |
| 66 | elif "model" in checkpoint: # ddp |
| 67 | model_weight = checkpoint["model"] |
| 68 | elif "state_dict" in checkpoint: |
| 69 | model_weight = checkpoint["state_dict"] |
| 70 | else: |
| 71 | raise Exception("please check model weight") |
| 72 | vq_model.load_state_dict(model_weight) |
| 73 | del checkpoint |
| 74 | |
| 75 | # Create folder to save samples: |
| 76 | folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}" |
| 77 | f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}") |
| 78 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" |
| 79 | if rank == 0: |
| 80 | os.makedirs(sample_folder_dir, exist_ok=True) |
| 81 | print(f"Saving .png samples at {sample_folder_dir}") |
| 82 | dist.barrier() |
| 83 | |
| 84 | # Setup data: |
| 85 | transform = transforms.Compose([ |
| 86 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), |
| 87 | transforms.ToTensor(), |
| 88 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) |
| 89 | ]) |
| 90 | |
| 91 | if args.dataset == 'imagenet': |
| 92 | dataset = build_dataset(args, transform=transform) |
| 93 | num_fid_samples = 50000 |
| 94 | elif args.dataset == 'coco': |
| 95 | dataset = build_dataset(args, transform=transform) |
| 96 | num_fid_samples = 5000 |
| 97 | else: |
| 98 | raise Exception("please check dataset") |
| 99 | |
| 100 | sampler = DistributedSampler( |
no test coverage detected