()
| 19 | |
| 20 | |
def main():
    """Command-line entry point: load a model checkpoint and evaluate it on a dataset.

    Steps, in the order they run:
      1. parse command-line flags via ``create_argparser()``;
      2. call ``dist_util.setup_dist()`` and ``logger.configure()``;
      3. build the model/diffusion pair from the flags that
         ``model_and_diffusion_defaults()`` knows about, load the weights
         stored at ``args.model_path`` (read with ``map_location="cpu"``),
         move the model to ``dist_util.dev()`` and switch it to eval mode;
      4. build a data loader over ``args.data_dir``;
      5. pass model, diffusion and data to ``run_bpd_evaluation``.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    # Forward only the flags that the model/diffusion factory accepts.
    factory_keys = model_and_diffusion_defaults().keys()
    model, diffusion = create_model_and_diffusion(**args_to_dict(args, factory_keys))

    # The checkpoint is deserialized onto the CPU first; the model is then
    # moved to whatever device dist_util reports.
    checkpoint_state = dist_util.load_state_dict(args.model_path, map_location="cpu")
    model.load_state_dict(checkpoint_state)
    target_device = dist_util.dev()
    model.to(target_device)
    model.eval()

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        # NOTE(review): presumably fixes the iteration order so evaluation is
        # reproducible -- confirm against load_data.
        deterministic=True,
    )

    logger.log("evaluating...")
    run_bpd_evaluation(
        model,
        diffusion,
        data,
        num_samples=args.num_samples,
        clip_denoised=args.clip_denoised,
    )
| 49 | |
| 50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): |
no test coverage detected