github.com/openai/guided-diffusion

Function main()

scripts/super_res_train.py, lines 21–60


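```python
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    args_to_dict,
    sr_create_model_and_diffusion,
    sr_model_and_diffusion_defaults,
)
from guided_diffusion.train_util import TrainLoop

# create_argparser() and load_superres_data() are defined elsewhere in
# scripts/super_res_train.py (outside lines 21-60).


def main():
    # Parse command-line flags; their defaults are registered by create_argparser().
    args = create_argparser().parse_args()

    # Set up distributed training and the run logger.
    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model...")
    # Forward only the flags that the super-resolution model/diffusion factory accepts.
    model, diffusion = sr_create_model_and_diffusion(
        **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    # Decides how diffusion timesteps are sampled for each training batch.
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    # Images are loaded at large_size; small_size sets the low-resolution input size.
    data = load_superres_data(
        args.data_dir,
        args.batch_size,
        large_size=args.large_size,
        small_size=args.small_size,
        class_cond=args.class_cond,
    )

    logger.log("training...")
    # TrainLoop handles optimization, EMA, logging and checkpointing; run_loop() starts it.
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
    ).run_loop()
```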
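The expression `args_to_dict(args, sr_model_and_diffusion_defaults().keys())` passes the model factory only the flags it knows about, while training-only flags such as `lr` stay behind for `TrainLoop`. A minimal sketch of that pattern follows; `MODEL_DEFAULTS`, `TRAIN_DEFAULTS`, `build_parser`, `select_args` and `str2bool` are hypothetical stand-ins with made-up default values, not the library's own code.

```python
import argparse


def str2bool(value):
    # argparse's type=bool treats any non-empty string as True, so parse explicitly.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


# Hypothetical subset of model/diffusion defaults (illustration only).
MODEL_DEFAULTS = {"large_size": 256, "small_size": 64, "learn_sigma": False}
# Hypothetical training-only defaults that the model factory must not receive.
TRAIN_DEFAULTS = {"lr": 1e-4, "batch_size": 1, "schedule_sampler": "uniform"}


def build_parser():
    parser = argparse.ArgumentParser()
    for key, value in {**TRAIN_DEFAULTS, **MODEL_DEFAULTS}.items():
        # Booleans need str2bool; other flags reuse the default value's type.
        value_type = str2bool if isinstance(value, bool) else type(value)
        parser.add_argument(f"--{key}", default=value, type=value_type)
    return parser


def select_args(args, keys):
    # Keep only the parsed attributes whose names appear in keys.
    return {k: getattr(args, k) for k in keys}


args = build_parser().parse_args(["--small_size", "32", "--lr", "3e-4"])
model_kwargs = select_args(args, MODEL_DEFAULTS.keys())
print(model_kwargs)  # {'large_size': 256, 'small_size': 32, 'learn_sigma': False}
```

Registering every default on one parser and then filtering by key set keeps a single command line for the whole script while each consumer receives only its own arguments.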
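`create_named_schedule_sampler` turns the `schedule_sampler` flag into an object that picks a diffusion timestep for each training example. The sketch below illustrates importance-weighted timestep sampling in NumPy under that assumption; the helper `sample_timesteps` is hypothetical. Timesteps are drawn from a distribution p, and each loss term is weighted by 1/(T·p_t), so the weighted loss remains an unbiased estimate of the uniform average over all T timesteps. With uniform weights every loss weight is 1.

```python
import numpy as np


def sample_timesteps(weights, batch_size, rng):
    # Normalize per-timestep weights into a probability distribution p.
    p = np.asarray(weights, dtype=np.float64)
    p = p / p.sum()
    # Draw one timestep index per batch element.
    t = rng.choice(len(p), size=batch_size, p=p)
    # Importance weights 1 / (T * p_t); all equal to 1 when p is uniform.
    loss_weights = 1.0 / (len(p) * p[t])
    return t, loss_weights


rng = np.random.default_rng(0)
num_timesteps = 1000
t, w = sample_timesteps(np.ones(num_timesteps), batch_size=8, rng=rng)
print(t.shape, w)  # uniform weighting: every loss weight is 1.0 (up to rounding)
```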
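`load_superres_data` is defined later in the same script and its body is not shown here. As a rough sketch, assuming the loader pairs each `large_size` batch with an area-downsampled `small_size` copy that is passed to the model under a `low_res` keyword, the pairing could look like this; NumPy stands in for PyTorch, and `area_downsample` and `superres_batches` are hypothetical names.

```python
import numpy as np


def area_downsample(images, small_size):
    # images: (N, C, H, W); the sketch requires square images and an integer scale factor.
    n, c, h, w = images.shape
    assert h == w and h % small_size == 0, "sketch assumes an integer scale factor"
    factor = h // small_size
    # Average each factor x factor block, which matches area interpolation
    # when the scale factor is an integer.
    blocks = images.reshape(n, c, small_size, factor, small_size, factor)
    return blocks.mean(axis=(3, 5))


def superres_batches(large_batches, small_size):
    # Attach a low-resolution copy of each high-resolution batch as conditioning.
    for large_batch, model_kwargs in large_batches:
        model_kwargs = dict(model_kwargs)
        model_kwargs["low_res"] = area_downsample(large_batch, small_size)
        yield large_batch, model_kwargs


large = np.arange(16, dtype=np.float64).reshape(1, 1, 4, 4)
low = area_downsample(large, 2)
print(low[0, 0])  # [[ 2.5  4.5] [10.5 12.5]]
```

The model then learns to produce the `large_size` image while seeing the `small_size` version as input; block averaging is only equivalent to area resizing when `large_size` is a multiple of `small_size`.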

Callers 1

super_res_train.py · File · 0.70

Calls 9

args_to_dict · Function · 0.90
TrainLoop · Class · 0.90
load_superres_data · Function · 0.85
log · Method · 0.80
run_loop · Method · 0.80
create_argparser · Function · 0.70
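Of the calls listed above, `TrainLoop(...).run_loop()` drives training, and `lr_anneal_steps` controls both the learning-rate schedule and when the loop ends. A sketch assuming linear decay of the learning rate to zero over `lr_anneal_steps` steps, with 0 meaning "no annealing, run until stopped"; `annealed_lr` and `should_continue` are hypothetical helpers, not `TrainLoop` methods.

```python
def annealed_lr(base_lr, step, lr_anneal_steps):
    # Assumed convention: 0 (or None) disables annealing and keeps the base rate.
    if not lr_anneal_steps:
        return base_lr
    frac_done = step / lr_anneal_steps
    # Linear decay from base_lr at step 0 toward 0 at step lr_anneal_steps.
    return base_lr * (1 - frac_done)


def should_continue(step, lr_anneal_steps):
    # Without annealing the loop has no built-in stopping point.
    return not lr_anneal_steps or step < lr_anneal_steps


schedule = [annealed_lr(1e-4, s, 4) for s in range(4)]
print(schedule)  # decays by 2.5e-5 per step from 1e-4
```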
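`TrainLoop` also receives `ema_rate`, used to keep exponential moving averages (EMA) of the model weights. A NumPy sketch, assuming `ema_rate` may be a single float or a comma-separated string of rates with one averaged copy kept per rate; `parse_ema_rates` and `update_ema` are hypothetical names.

```python
import numpy as np


def parse_ema_rates(ema_rate):
    # Accept a single float or a comma-separated string such as "0.9999,0.999".
    if isinstance(ema_rate, float):
        return [ema_rate]
    return [float(x) for x in ema_rate.split(",")]


def update_ema(ema_params, params, rate):
    # In place: ema <- rate * ema + (1 - rate) * current parameters.
    for ema_p, p in zip(ema_params, params):
        ema_p *= rate
        ema_p += (1 - rate) * p


rates = parse_ema_rates("0.9,0.5")
params = [np.array([1.0, 2.0])]
# EMA copies typically start as copies of the parameters; zeros here only
# make the arithmetic of a single update easy to see.
ema_copies = {r: [np.zeros(2)] for r in rates}
for rate in rates:
    update_ema(ema_copies[rate], params, rate)
print(ema_copies)
```

A rate close to 1 (such as 0.9999) averages over many recent steps, which smooths out noise in the weights at the cost of lagging behind the latest parameters.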

Tested by

no test coverage detected