MCPcopy
hub / github.com/openai/guided-diffusion / main

Function main

scripts/image_train.py:19–57  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

17
18
def main():
    """Train a diffusion model configured from command-line options.

    Steps, in order:
      1. Parse the CLI options built by ``create_argparser()``.
      2. Call ``dist_util.setup_dist()`` and ``logger.configure()``.
      3. Build the model/diffusion pair, move the model to
         ``dist_util.dev()``, and create the named schedule sampler.
      4. Open the training data with ``load_data``.
      5. Construct a ``TrainLoop`` from the options and call its ``run_loop``.

    Takes no parameters and returns ``None``. All configuration comes from
    the parsed command line.
    """
    opts = create_argparser().parse_args()

    # Distributed setup and logger configuration are done before any model or
    # data object is constructed.
    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    # args_to_dict receives the keys of model_and_diffusion_defaults(),
    # presumably to pick out only the model/diffusion options from `opts`
    # (verify against args_to_dict).
    factory_kwargs = args_to_dict(opts, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(**factory_kwargs)
    model.to(dist_util.dev())
    sampler = create_named_schedule_sampler(opts.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_data(
        data_dir=opts.data_dir,
        batch_size=opts.batch_size,
        image_size=opts.image_size,
        class_cond=opts.class_cond,
    )

    logger.log("training...")
    trainer = TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=opts.batch_size,
        microbatch=opts.microbatch,
        lr=opts.lr,
        ema_rate=opts.ema_rate,
        log_interval=opts.log_interval,
        save_interval=opts.save_interval,
        resume_checkpoint=opts.resume_checkpoint,
        use_fp16=opts.use_fp16,
        fp16_scale_growth=opts.fp16_scale_growth,
        schedule_sampler=sampler,
        weight_decay=opts.weight_decay,
        lr_anneal_steps=opts.lr_anneal_steps,
    )
    trainer.run_loop()
58
59
60def create_argparser():

Callers 1

image_train.py (File) · 0.70

Calls 9

args_to_dict (Function) · 0.90
load_data (Function) · 0.90
TrainLoop (Class) · 0.90
log (Method) · 0.80
run_loop (Method) · 0.80
create_argparser (Function) · 0.70

Tested by

no test coverage detected