()
| 17 | |
| 18 | |
def main():
    """Run a training job configured from the command line.

    Steps, in order: parse the arguments produced by ``create_argparser()``,
    call ``dist_util.setup_dist()`` and ``logger.configure()``, build the
    model/diffusion pair and the schedule sampler, create the data loader,
    and finally construct a ``TrainLoop`` and invoke its ``run_loop()``.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    # Forward only the parsed options whose names appear in the
    # model/diffusion defaults.
    model_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(**model_kwargs)
    model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
    )

    logger.log("training...")
    # Training options that are passed straight through from the parsed
    # arguments under the same keyword name.
    forwarded_names = (
        "batch_size",
        "microbatch",
        "lr",
        "ema_rate",
        "log_interval",
        "save_interval",
        "resume_checkpoint",
        "use_fp16",
        "fp16_scale_growth",
        "weight_decay",
        "lr_anneal_steps",
    )
    forwarded = {name: getattr(args, name) for name in forwarded_names}
    trainer = TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        schedule_sampler=schedule_sampler,
        **forwarded,
    )
    trainer.run_loop()
| 58 | |
| 59 | |
| 60 | def create_argparser(): |
no test coverage detected