()
| 19 | |
| 20 | |
def main():
    """Train a super-resolution diffusion model configured from the command line.

    Steps, in order:
      1. parse CLI flags,
      2. initialise distributed state and the logger,
      3. build the model/diffusion pair and a timestep schedule sampler,
      4. build the super-resolution data loader,
      5. hand everything to ``TrainLoop`` and run it.
    """
    opts = create_argparser().parse_args()

    # Distributed setup and logger configuration happen before any logging.
    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model...")
    # Only the flags that belong to the SR model/diffusion factory are forwarded.
    model_kwargs = args_to_dict(opts, sr_model_and_diffusion_defaults().keys())
    model, diffusion = sr_create_model_and_diffusion(**model_kwargs)
    model.to(dist_util.dev())
    sampler = create_named_schedule_sampler(opts.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    loader = load_superres_data(
        opts.data_dir,
        opts.batch_size,
        large_size=opts.large_size,
        small_size=opts.small_size,
        class_cond=opts.class_cond,
    )

    logger.log("training...")
    trainer = TrainLoop(
        model=model,
        diffusion=diffusion,
        data=loader,
        batch_size=opts.batch_size,
        microbatch=opts.microbatch,
        lr=opts.lr,
        ema_rate=opts.ema_rate,
        log_interval=opts.log_interval,
        save_interval=opts.save_interval,
        resume_checkpoint=opts.resume_checkpoint,
        use_fp16=opts.use_fp16,
        fp16_scale_growth=opts.fp16_scale_growth,
        schedule_sampler=sampler,
        weight_decay=opts.weight_decay,
        lr_anneal_steps=opts.lr_anneal_steps,
    )
    trainer.run_loop()
| 61 | |
| 62 | |
| 63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False): |
no test coverage detected