()
| 26 | |
| 27 | |
| 28 | def main(): |
| 29 | args = create_argparser().parse_args() |
| 30 | |
| 31 | dist_util.setup_dist() |
| 32 | logger.configure() |
| 33 | |
| 34 | logger.log("creating model and diffusion...") |
| 35 | model, diffusion = create_classifier_and_diffusion( |
| 36 | **args_to_dict(args, classifier_and_diffusion_defaults().keys()) |
| 37 | ) |
| 38 | model.to(dist_util.dev()) |
| 39 | if args.noised: |
| 40 | schedule_sampler = create_named_schedule_sampler( |
| 41 | args.schedule_sampler, diffusion |
| 42 | ) |
| 43 | |
| 44 | resume_step = 0 |
| 45 | if args.resume_checkpoint: |
| 46 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint) |
| 47 | if dist.get_rank() == 0: |
| 48 | logger.log( |
| 49 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step" |
| 50 | ) |
| 51 | model.load_state_dict( |
| 52 | dist_util.load_state_dict( |
| 53 | args.resume_checkpoint, map_location=dist_util.dev() |
| 54 | ) |
| 55 | ) |
| 56 | |
| 57 | # Needed for creating correct EMAs and fp16 parameters. |
| 58 | dist_util.sync_params(model.parameters()) |
| 59 | |
| 60 | mp_trainer = MixedPrecisionTrainer( |
| 61 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0 |
| 62 | ) |
| 63 | |
| 64 | model = DDP( |
| 65 | model, |
| 66 | device_ids=[dist_util.dev()], |
| 67 | output_device=dist_util.dev(), |
| 68 | broadcast_buffers=False, |
| 69 | bucket_cap_mb=128, |
| 70 | find_unused_parameters=False, |
| 71 | ) |
| 72 | |
| 73 | logger.log("creating data loader...") |
| 74 | data = load_data( |
| 75 | data_dir=args.data_dir, |
| 76 | batch_size=args.batch_size, |
| 77 | image_size=args.image_size, |
| 78 | class_cond=True, |
| 79 | random_crop=True, |
| 80 | ) |
| 81 | if args.val_data_dir: |
| 82 | val_data = load_data( |
| 83 | data_dir=args.val_data_dir, |
| 84 | batch_size=args.batch_size, |
| 85 | image_size=args.image_size, |
no test coverage detected