MCPcopy Index your code
hub / github.com/openai/improved-diffusion / main

Function main

scripts/image_sample.py:23–88  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

21
22
23def main():
24 args = create_argparser().parse_args()
25
26 dist_util.setup_dist()
27 logger.configure()
28
29 logger.log("creating model and diffusion...")
30 model, diffusion = create_model_and_diffusion(
31 **args_to_dict(args, model_and_diffusion_defaults().keys())
32 )
33 model.load_state_dict(
34 dist_util.load_state_dict(args.model_path, map_location="cpu")
35 )
36 model.to(dist_util.dev())
37 model.eval()
38
39 logger.log("sampling...")
40 all_images = []
41 all_labels = []
42 while len(all_images) * args.batch_size < args.num_samples:
43 model_kwargs = {}
44 if args.class_cond:
45 classes = th.randint(
46 low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
47 )
48 model_kwargs["y"] = classes
49 sample_fn = (
50 diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
51 )
52 sample = sample_fn(
53 model,
54 (args.batch_size, 3, args.image_size, args.image_size),
55 clip_denoised=args.clip_denoised,
56 model_kwargs=model_kwargs,
57 )
58 sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
59 sample = sample.permute(0, 2, 3, 1)
60 sample = sample.contiguous()
61
62 gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
63 dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
64 all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
65 if args.class_cond:
66 gathered_labels = [
67 th.zeros_like(classes) for _ in range(dist.get_world_size())
68 ]
69 dist.all_gather(gathered_labels, classes)
70 all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
71 logger.log(f"created {len(all_images) * args.batch_size} samples")
72
73 arr = np.concatenate(all_images, axis=0)
74 arr = arr[: args.num_samples]
75 if args.class_cond:
76 label_arr = np.concatenate(all_labels, axis=0)
77 label_arr = label_arr[: args.num_samples]
78 if dist.get_rank() == 0:
79 shape_str = "x".join([str(x) for x in arr.shape])
80 out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")

Callers 1

image_sample.py — File · 0.70

Calls 6

args_to_dict — Function · 0.90
log — Method · 0.80
get_dir — Method · 0.80
create_argparser — Function · 0.70

Tested by

no test coverage detected