(model, diffusion, data, num_samples, clip_denoised)
| 48 | |
| 49 | |
def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
    """Evaluate bits-per-dim metrics over batches drawn from ``data``.

    Batches are pulled with ``next(data)`` until at least ``num_samples``
    samples have been counted across all ranks. For every batch the metrics
    returned by ``diffusion.calc_bpd_loop`` are averaged across ranks, the
    running mean of ``total_bpd`` is logged, and at the end rank 0 writes the
    batch-averaged ``vb``, ``mse`` and ``xstart_mse`` terms to
    ``<logger.get_dir()>/<name>_terms.npz``.

    :param model: the model forwarded to ``diffusion.calc_bpd_loop``.
    :param diffusion: object providing ``calc_bpd_loop``; its result must be
                      indexable by "vb", "mse", "xstart_mse" and "total_bpd".
    :param data: an iterator yielding ``(batch, model_kwargs)`` pairs.
    :param num_samples: minimum number of samples to evaluate; must be > 0.
    :param clip_denoised: forwarded unchanged to ``calc_bpd_loop``.
    :raises ValueError: if ``num_samples`` is not positive.
    """
    # With a non-positive target the loop below would never run: rank 0 would
    # then fail in np.stack() on an empty list while every other rank waited
    # forever in dist.barrier(). Fail identically on all ranks instead, before
    # any collective operation is issued.
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    world_size = dist.get_world_size()
    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        device = dist_util.dev()
        batch = batch.to(device)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        for key, term_list in all_metrics.items():
            # Pre-divide by the world size so that the (default, summing)
            # all_reduce leaves the cross-rank mean in ``terms``.
            terms = minibatch_metrics[key].mean(dim=0) / world_size
            dist.all_reduce(terms)
            term_list.append(terms.detach().cpu().numpy())

        total_bpd = minibatch_metrics["total_bpd"]
        total_bpd = total_bpd.mean() / world_size
        dist.all_reduce(total_bpd)
        all_bpd.append(total_bpd.item())
        # NOTE(review): counts this rank's batch size once per rank, i.e. it
        # assumes every rank processed a batch of the same size -- confirm.
        num_complete += world_size * batch.shape[0]

        logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")

    # Only rank 0 writes files; the reduced values were appended on every rank.
    if dist.get_rank() == 0:
        for name, terms in all_metrics.items():
            out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            # Unweighted mean over batches of the per-batch averaged terms.
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    # Keep the other ranks from running ahead while rank 0 is still saving.
    dist.barrier()
    logger.log("evaluation complete")
| 83 | |
| 84 | |
| 85 | def create_argparser(): |
no test coverage detected