(diffusion, ts, losses)
| 293 | |
| 294 | |
def log_loss_dict(diffusion, ts, losses):
    """Log each loss's batch mean, plus per-sample values bucketed by timestep quartile.

    For every ``name -> per_sample`` entry in ``losses``:

    - the mean (``per_sample.mean().item()``) is passed to ``logger.logkv_mean``
      under ``name``;
    - each sample's loss is passed to ``logger.logkv_mean`` under
      ``f"{name}_q{bucket}"``, where ``bucket = int(4 * t / diffusion.num_timesteps)``
      for that sample's timestep ``t``.

    :param diffusion: object exposing ``num_timesteps``.
    :param ts: batch timesteps; ``.cpu().numpy()`` is called on it.
    :param losses: mapping of loss name to per-sample loss values. These are
        presumably aligned element-wise with ``ts``; ``zip`` silently truncates
        to the shorter of the two.
    """
    for name, per_sample in losses.items():
        logger.logkv_mean(name, per_sample.mean().item())

        timesteps = ts.cpu().numpy()
        sample_losses = per_sample.detach().cpu().numpy()

        # Bucket each sample by the quarter of the schedule its timestep falls in
        # (0..3 when timesteps lie in [0, num_timesteps)).
        for step, loss in zip(timesteps, sample_losses):
            bucket = int(4 * step / diffusion.num_timesteps)
            logger.logkv_mean(f"{name}_q{bucket}", loss)
no test coverage detected