MCPcopy
hub / github.com/openai/guided-diffusion / main

Function main

scripts/image_sample.py:23–90  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

21
22
23def main():
24 args = create_argparser().parse_args()
25
26 dist_util.setup_dist()
27 logger.configure()
28
29 logger.log("creating model and diffusion...")
30 model, diffusion = create_model_and_diffusion(
31 **args_to_dict(args, model_and_diffusion_defaults().keys())
32 )
33 model.load_state_dict(
34 dist_util.load_state_dict(args.model_path, map_location="cpu")
35 )
36 model.to(dist_util.dev())
37 if args.use_fp16:
38 model.convert_to_fp16()
39 model.eval()
40
41 logger.log("sampling...")
42 all_images = []
43 all_labels = []
44 while len(all_images) * args.batch_size < args.num_samples:
45 model_kwargs = {}
46 if args.class_cond:
47 classes = th.randint(
48 low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
49 )
50 model_kwargs["y"] = classes
51 sample_fn = (
52 diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
53 )
54 sample = sample_fn(
55 model,
56 (args.batch_size, 3, args.image_size, args.image_size),
57 clip_denoised=args.clip_denoised,
58 model_kwargs=model_kwargs,
59 )
60 sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
61 sample = sample.permute(0, 2, 3, 1)
62 sample = sample.contiguous()
63
64 gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
65 dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
66 all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
67 if args.class_cond:
68 gathered_labels = [
69 th.zeros_like(classes) for _ in range(dist.get_world_size())
70 ]
71 dist.all_gather(gathered_labels, classes)
72 all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
73 logger.log(f"created {len(all_images) * args.batch_size} samples")
74
75 arr = np.concatenate(all_images, axis=0)
76 arr = arr[: args.num_samples]
77 if args.class_cond:
78 label_arr = np.concatenate(all_labels, axis=0)
79 label_arr = label_arr[: args.num_samples]
80 if dist.get_rank() == 0:

Callers 1

image_sample.py · File · 0.70

Calls 7

args_to_dict · Function · 0.90
log · Method · 0.80
get_dir · Method · 0.80
create_argparser · Function · 0.70
convert_to_fp16 · Method · 0.45

Tested by

no test coverage detected