github.com/openai/guided-diffusion · main

Function main()

scripts/classifier_train.py:28–167

def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    # Timestep sampler; only needed when the classifier is trained on noised inputs.
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    # Optionally resume: rank 0 loads the checkpoint, and sync_params below
    # broadcasts the weights to the other ranks.
    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(
                    args.resume_checkpoint, map_location=dist_util.dev()
                )
            )

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

    # Wrap in DistributedDataParallel for multi-GPU training.
    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            # ... excerpt ends here; main() continues through line 167 of scripts/classifier_train.py
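
`create_classifier_and_diffusion` receives only the parsed flags whose names appear in `classifier_and_diffusion_defaults()`, so training-only flags such as `batch_size` never reach the model constructor. A minimal stand-alone sketch of that filtering, assuming `args_to_dict` just reads each listed attribute off the `argparse.Namespace`; `DEFAULTS` and its three keys are illustrative placeholders, not the repository's real defaults:

```python
import argparse

# Illustrative stand-in for classifier_and_diffusion_defaults(); the real
# dict has more keys and may use different values.
DEFAULTS = {"image_size": 64, "classifier_width": 128, "diffusion_steps": 1000}


def args_to_dict(args: argparse.Namespace, keys) -> dict:
    """Keep only the requested attributes of a parsed Namespace (sketch)."""
    return {k: getattr(args, k) for k in keys}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # One flag per model/diffusion default (all ints in this toy example).
    for key, value in DEFAULTS.items():
        parser.add_argument(f"--{key}", type=type(value), default=value)
    # A training-only flag that should not be forwarded to the constructor.
    parser.add_argument("--batch_size", type=int, default=128)
    return parser


args = build_parser().parse_args(["--image_size", "128", "--batch_size", "32"])
model_kwargs = args_to_dict(args, DEFAULTS.keys())
print(model_kwargs)  # batch_size is not included
```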
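
When `args.noised` is set, the script builds a timestep sampler by name (`args.schedule_sampler`), which the rest of `main()` (outside this excerpt) uses to choose noise levels for each batch. A minimal sketch of the simplest case, assuming a uniform sampler that draws each timestep from `[0, num_timesteps)` and returns importance weights `1 / (T * p(t))`, which are all 1 when `p` is uniform; `sample_uniform_timesteps` is a hypothetical name:

```python
import random
from typing import List, Tuple


def sample_uniform_timesteps(
    batch_size: int, num_timesteps: int, rng: random.Random
) -> Tuple[List[int], List[float]]:
    """Draw one timestep per example uniformly and return (timesteps, weights)."""
    timesteps = [rng.randrange(num_timesteps) for _ in range(batch_size)]
    # With p(t) = 1/T, the importance weight 1 / (T * p(t)) equals 1 for every t.
    weights = [1.0] * batch_size
    return timesteps, weights


rng = random.Random(0)  # seeded for reproducibility
t, w = sample_uniform_timesteps(batch_size=8, num_timesteps=1000, rng=rng)
print(t, w)
```

A non-uniform sampler would change only the proposal `p(t)`; the weights then correct the loss so it stays an unbiased estimate of the uniform-timestep objective.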
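
`parse_resume_step_from_filename` recovers the step count that seeds `resume_step`, which appears in the log line above. A regex-based stand-in (not the repository's own implementation), assuming checkpoints are named `modelNNNNNN.pt` with a zero-padded step count; `checkpoint_name` and `parse_resume_step` are hypothetical helpers that show the round trip:

```python
import os
import re

# Assumed naming convention: <dir>/modelNNNNNN.pt, NNNNNN = zero-padded step count.
_STEP_RE = re.compile(r"^model(\d+)\.pt$")


def checkpoint_name(step: int) -> str:
    """Format the filename a checkpoint for `step` would get under this convention."""
    return f"model{step:06d}.pt"


def parse_resume_step(path: str) -> int:
    """Return the step encoded in a checkpoint path, or 0 if the name doesn't match."""
    match = _STEP_RE.match(os.path.basename(path))
    return int(match.group(1)) if match else 0


print(parse_resume_step("/ckpts/model012000.pt"))
print(parse_resume_step("/ckpts/latest.pt"))
print(parse_resume_step(checkpoint_name(345)))
```

Falling back to 0 means an unrecognized filename is treated as a fresh start rather than raising an error.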
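
`MixedPrecisionTrainer` is constructed with `initial_lg_loss_scale=16.0`. Reading "lg" as log base 2, the loss would start out multiplied by 2**16 = 65536, and the scale presumably only matters when `classifier_use_fp16` is on. A generic sketch of dynamic loss scaling in log2 space, not the trainer's actual code: on overflow the step is skipped and the log-scale drops by 1 (halving the scale); otherwise gradients are unscaled and the log-scale grows slowly. The class name and the `growth=1e-3` rate are assumptions:

```python
import math
from typing import List, Optional


class DynamicLossScaler:
    """Log2-space dynamic loss scaling (illustrative sketch)."""

    def __init__(self, initial_lg_loss_scale: float = 16.0, growth: float = 1e-3):
        self.lg_loss_scale = initial_lg_loss_scale  # loss is multiplied by 2**lg_loss_scale
        self.growth = growth  # assumed per-step growth, in log2 units

    @property
    def scale(self) -> float:
        return 2.0 ** self.lg_loss_scale

    def unscale_or_skip(self, scaled_grads: List[float]) -> Optional[List[float]]:
        """Return unscaled gradients, or None to signal 'skip this optimizer step'."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.lg_loss_scale -= 1.0  # overflow: halve the scale for the next step
            return None
        inv = 1.0 / self.scale
        unscaled = [g * inv for g in scaled_grads]  # undo scaling with the current scale
        self.lg_loss_scale += self.growth  # no overflow: grow the scale slowly
        return unscaled


scaler = DynamicLossScaler()
print(scaler.scale)  # 65536.0
print(scaler.unscale_or_skip([float("inf"), 1.0]))  # None (step skipped, scale halved)
print(scaler.unscale_or_skip([32768.0, -16384.0]))  # [1.0, -0.5]
```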

Callers: 1

Calls: 15 (7 shown)

optimize (method) · 0.95
args_to_dict (function) · 0.90
load_data (function) · 0.90
set_annealed_lr (function) · 0.85
forward_backward_log (function) · 0.85
save_model (function) · 0.85
log (method) · 0.80
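
`set_annealed_lr` appears among the callees. A minimal sketch assuming linear decay from the base rate to zero, with progress computed as `(step + resume_step) / iterations` so that a resumed run continues the schedule instead of restarting it; `param_groups` mirrors the list-of-dicts shape of a PyTorch optimizer's `param_groups`, and `iterations` and the numbers used are hypothetical:

```python
from typing import Dict, List


def set_annealed_lr(
    param_groups: List[Dict[str, float]], base_lr: float, frac_done: float
) -> float:
    """Set every group's lr to base_lr * (1 - frac_done) and return it (assumed linear schedule)."""
    lr = base_lr * (1.0 - frac_done)
    for group in param_groups:
        group["lr"] = lr
    return lr


groups = [{"lr": 3e-4}, {"lr": 3e-4}]
iterations, resume_step = 1000, 200  # hypothetical run length and resume point
for step in (0, 300, 799):
    lr = set_annealed_lr(groups, base_lr=3e-4, frac_done=(step + resume_step) / iterations)
    print(f"global step {step + resume_step}: lr={lr:.2e}")
```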
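
`forward_backward_log` is also called from `main()`. Assuming it follows the common gradient-accumulation pattern, a batch is split into microbatches and each microbatch's mean loss is weighted by its share of the batch before `backward()`, so the accumulated gradient matches a single full-batch mean. The pure-Python sketch below checks that weighting on scalar losses; both function names are hypothetical, and treating a non-positive size as "no split" is an assumption of this sketch:

```python
from typing import Iterator, Sequence


def microbatches(batch: Sequence[float], size: int) -> Iterator[Sequence[float]]:
    """Yield consecutive slices of at most `size` items (size <= 0 means no split)."""
    if size <= 0:
        size = len(batch)
    for i in range(0, len(batch), size):
        yield batch[i : i + size]


def accumulated_loss(per_example_losses: Sequence[float], microbatch: int) -> float:
    """Sum of microbatch-mean losses, each weighted by its share of the full batch."""
    n = len(per_example_losses)
    total = 0.0
    for sub in microbatches(per_example_losses, microbatch):
        sub_mean = sum(sub) / len(sub)
        total += sub_mean * len(sub) / n  # the weight a per-microbatch backward() would use
    return total


losses = [0.5, 1.5, 2.0, 4.0, 2.0]  # made-up per-example losses
print(sum(losses) / len(losses))
print(accumulated_loss(losses, microbatch=2))
```

Weighting by `len(sub) / n` rather than `1 / num_microbatches` keeps the result correct when the last microbatch is smaller than the others.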

Tested by: no test coverage detected