MCPcopy
hub / github.com/FoundationVision/LlamaGen / main

Function main

tokenizer/tokenizer_image/reconstruction_vq_ddp.py:43–179  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

41
42
43def main(args):
44 # Setup PyTorch:
45 assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
46 torch.set_grad_enabled(False)
47
48 # Setup DDP:
49 dist.init_process_group("nccl")
50 rank = dist.get_rank()
51 device = rank % torch.cuda.device_count()
52 seed = args.global_seed * dist.get_world_size() + rank
53 torch.manual_seed(seed)
54 torch.cuda.set_device(device)
55 print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
56
57 # create and load model
58 vq_model = VQ_models[args.vq_model](
59 codebook_size=args.codebook_size,
60 codebook_embed_dim=args.codebook_embed_dim)
61 vq_model.to(device)
62 vq_model.eval()
63 checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
64 if "ema" in checkpoint: # ema
65 model_weight = checkpoint["ema"]
66 elif "model" in checkpoint: # ddp
67 model_weight = checkpoint["model"]
68 elif "state_dict" in checkpoint:
69 model_weight = checkpoint["state_dict"]
70 else:
71 raise Exception("please check model weight")
72 vq_model.load_state_dict(model_weight)
73 del checkpoint
74
75 # Create folder to save samples:
76 folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
77 f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
78 sample_folder_dir = f"{args.sample_dir}/{folder_name}"
79 if rank == 0:
80 os.makedirs(sample_folder_dir, exist_ok=True)
81 print(f"Saving .png samples at {sample_folder_dir}")
82 dist.barrier()
83
84 # Setup data:
85 transform = transforms.Compose([
86 transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
87 transforms.ToTensor(),
88 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
89 ])
90
91 if args.dataset == 'imagenet':
92 dataset = build_dataset(args, transform=transform)
93 num_fid_samples = 50000
94 elif args.dataset == 'coco':
95 dataset = build_dataset(args, transform=transform)
96 num_fid_samples = 5000
97 else:
98 raise Exception("please check dataset")
99
100 sampler = DistributedSampler(

Callers 1

Calls 7

center_crop_arr (Function) · 0.90
build_dataset (Function) · 0.90
print (Function) · 0.85
load (Method) · 0.80
encode (Method) · 0.45
decode_code (Method) · 0.45

Tested by

no test coverage detected