MCPcopy
hub / github.com/FoundationVision/LlamaGen / main

Function main

tokenizer/vae/reconstruction_vae_ddp.py:81–196  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

79
80
81def main(args):
82 # Setup PyTorch:
83 assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
84 torch.set_grad_enabled(False)
85
86 # Setup DDP:
87 dist.init_process_group("nccl")
88 rank = dist.get_rank()
89 device = rank % torch.cuda.device_count()
90 seed = args.global_seed * dist.get_world_size() + rank
91 torch.manual_seed(seed)
92 torch.cuda.set_device(device)
93 print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
94
95 # load vae
96 vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device)
97
98 # Create folder to save samples:
99 folder_name = f"stabilityai-{args.vae}-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
100 sample_folder_dir = f"{args.sample_dir}/{folder_name}"
101 if rank == 0:
102 os.makedirs(sample_folder_dir, exist_ok=True)
103 print(f"Saving .png samples at {sample_folder_dir}")
104 dist.barrier()
105
106 # Setup data:
107 transform = transforms.Compose([
108 transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
109 transforms.ToTensor(),
110 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
111 ])
112 if args.dataset == 'imagenet':
113 dataset = ImageFolder(args.data_path, transform=transform)
114 num_fid_samples = 50000
115 elif args.dataset == 'coco':
116 dataset = SingleFolderDataset(args.data_path, transform=transform)
117 num_fid_samples = 5000
118 else:
119 raise Exception("please check dataset")
120
121 sampler = DistributedSampler(
122 dataset,
123 num_replicas=dist.get_world_size(),
124 rank=rank,
125 shuffle=False,
126 seed=args.global_seed
127 )
128 loader = DataLoader(
129 dataset,
130 batch_size=args.per_proc_batch_size,
131 shuffle=False,
132 sampler=sampler,
133 num_workers=args.num_workers,
134 pin_memory=True,
135 drop_last=False
136 )
137
138 # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:

Callers 1

Calls 8

print (Function) · 0.85
from_pretrained (Method) · 0.80
sample (Method) · 0.80
center_crop_arr (Function) · 0.70
SingleFolderDataset (Class) · 0.70
encode (Method) · 0.45
decode (Method) · 0.45

Tested by

no test coverage detected