MCPcopy
hub / github.com/FoundationVision/LlamaGen / main

Function main

tokenizer/consistencydecoder/reconstruction_cd_ddp.py:81–195  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

79
80
81def main(args):
82 # Setup PyTorch:
83 assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
84 torch.set_grad_enabled(False)
85
86 # Setup env
87 dist.init_process_group("nccl")
88 rank = dist.get_rank()
89 device = rank % torch.cuda.device_count()
90 seed = args.global_seed * dist.get_world_size() + rank
91 torch.manual_seed(seed)
92 torch.cuda.set_device(device)
93 print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
94
95 # create and load model
96 vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device))
97
98 # Create folder to save samples:
99 folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
100 sample_folder_dir = f"{args.sample_dir}/{folder_name}"
101 if rank == 0:
102 os.makedirs(sample_folder_dir, exist_ok=True)
103 print(f"Saving .png samples at {sample_folder_dir}")
104 dist.barrier()
105
106 # Setup data:
107 transform = transforms.Compose([
108 transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
109 transforms.ToTensor(),
110 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
111 ])
112 if args.dataset == 'imagenet':
113 dataset = ImageFolder(args.data_path, transform=transform)
114 num_fid_samples = 50000
115 elif args.dataset == 'coco':
116 dataset = SingleFolderDataset(args.data_path, transform=transform)
117 num_fid_samples = 5000
118 else:
119 raise Exception("please check dataset")
120 sampler = DistributedSampler(
121 dataset,
122 num_replicas=dist.get_world_size(),
123 rank=rank,
124 shuffle=False,
125 seed=args.global_seed
126 )
127 loader = DataLoader(
128 dataset,
129 batch_size=args.per_proc_batch_size,
130 shuffle=False,
131 sampler=sampler,
132 num_workers=args.num_workers,
133 pin_memory=True,
134 drop_last=False
135 )
136
137 # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
138 n = args.per_proc_batch_size

Callers 1

Calls 8

printFunction · 0.85
from_pretrainedMethod · 0.80
sampleMethod · 0.80
center_crop_arrFunction · 0.70
SingleFolderDatasetClass · 0.70
encodeMethod · 0.45
decodeMethod · 0.45

Tested by

no test coverage detected