(args)
| 79 | |
| 80 | |
| 81 | def main(args): |
| 82 | # Setup PyTorch: |
| 83 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" |
| 84 | torch.set_grad_enabled(False) |
| 85 | |
| 86 | # Setup env |
| 87 | dist.init_process_group("nccl") |
| 88 | rank = dist.get_rank() |
| 89 | device = rank % torch.cuda.device_count() |
| 90 | seed = args.global_seed * dist.get_world_size() + rank |
| 91 | torch.manual_seed(seed) |
| 92 | torch.cuda.set_device(device) |
| 93 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") |
| 94 | |
| 95 | # create and load model |
| 96 | vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device)) |
| 97 | |
| 98 | # Create folder to save samples: |
| 99 | folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}" |
| 100 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" |
| 101 | if rank == 0: |
| 102 | os.makedirs(sample_folder_dir, exist_ok=True) |
| 103 | print(f"Saving .png samples at {sample_folder_dir}") |
| 104 | dist.barrier() |
| 105 | |
| 106 | # Setup data: |
| 107 | transform = transforms.Compose([ |
| 108 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), |
| 109 | transforms.ToTensor(), |
| 110 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) |
| 111 | ]) |
| 112 | if args.dataset == 'imagenet': |
| 113 | dataset = ImageFolder(args.data_path, transform=transform) |
| 114 | num_fid_samples = 50000 |
| 115 | elif args.dataset == 'coco': |
| 116 | dataset = SingleFolderDataset(args.data_path, transform=transform) |
| 117 | num_fid_samples = 5000 |
| 118 | else: |
| 119 | raise Exception("please check dataset") |
| 120 | sampler = DistributedSampler( |
| 121 | dataset, |
| 122 | num_replicas=dist.get_world_size(), |
| 123 | rank=rank, |
| 124 | shuffle=False, |
| 125 | seed=args.global_seed |
| 126 | ) |
| 127 | loader = DataLoader( |
| 128 | dataset, |
| 129 | batch_size=args.per_proc_batch_size, |
| 130 | shuffle=False, |
| 131 | sampler=sampler, |
| 132 | num_workers=args.num_workers, |
| 133 | pin_memory=True, |
| 134 | drop_last=False |
| 135 | ) |
| 136 | |
| 137 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: |
| 138 | n = args.per_proc_batch_size |
no test coverage detected