MCPcopy
hub / github.com/openai/guided-diffusion / main

Function main

scripts/classifier_sample.py:26–110  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

24
25
26def main():
27 args = create_argparser().parse_args()
28
29 dist_util.setup_dist()
30 logger.configure()
31
32 logger.log("creating model and diffusion...")
33 model, diffusion = create_model_and_diffusion(
34 **args_to_dict(args, model_and_diffusion_defaults().keys())
35 )
36 model.load_state_dict(
37 dist_util.load_state_dict(args.model_path, map_location="cpu")
38 )
39 model.to(dist_util.dev())
40 if args.use_fp16:
41 model.convert_to_fp16()
42 model.eval()
43
44 logger.log("loading classifier...")
45 classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
46 classifier.load_state_dict(
47 dist_util.load_state_dict(args.classifier_path, map_location="cpu")
48 )
49 classifier.to(dist_util.dev())
50 if args.classifier_use_fp16:
51 classifier.convert_to_fp16()
52 classifier.eval()
53
54 def cond_fn(x, t, y=None):
55 assert y is not None
56 with th.enable_grad():
57 x_in = x.detach().requires_grad_(True)
58 logits = classifier(x_in, t)
59 log_probs = F.log_softmax(logits, dim=-1)
60 selected = log_probs[range(len(logits)), y.view(-1)]
61 return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
62
63 def model_fn(x, t, y=None):
64 assert y is not None
65 return model(x, t, y if args.class_cond else None)
66
67 logger.log("sampling...")
68 all_images = []
69 all_labels = []
70 while len(all_images) * args.batch_size < args.num_samples:
71 model_kwargs = {}
72 classes = th.randint(
73 low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
74 )
75 model_kwargs["y"] = classes
76 sample_fn = (
77 diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
78 )
79 sample = sample_fn(
80 model_fn,
81 (args.batch_size, 3, args.image_size, args.image_size),
82 clip_denoised=args.clip_denoised,
83 model_kwargs=model_kwargs,

Callers 1

Calls 9

args_to_dict · Function · 0.90
create_classifier · Function · 0.90
classifier_defaults · Function · 0.90
log · Method · 0.80
get_dir · Method · 0.80
create_argparser · Function · 0.70
convert_to_fp16 · Method · 0.45

Tested by

no test coverage detected