(diffusion, ts, losses)
| 348 | |
| 349 | |
def log_loss_dict(diffusion, ts, losses):
    """Log each loss term's batch mean plus its mean per timestep quarter.

    For every ``key`` in ``losses`` this logs ``key`` with the batch mean of
    the loss. It also logs ``key_q0`` .. ``key_q3``, where the suffix is the
    quarter of ``[0, diffusion.num_timesteps)`` that each sample's timestep
    falls into. ``q0`` holds the lowest timestep indices. The suffix buckets
    the timestep range; it is not a quantile of the loss values themselves.
    All values are written via ``logger.logkv_mean``. Presumably that
    accumulates a running mean per key; confirm against the logger module.

    :param diffusion: object exposing ``num_timesteps``, used to bucket ``ts``.
    :param ts: tensor of timesteps. Presumably this has one entry per sample,
        aligned with each loss tensor (it is zipped against them); verify
        against callers.
    :param losses: dict mapping loss names to loss tensors. Presumably each
        tensor is per-sample, with the same length as ``ts``.
    """
    # Transfer/convert the timesteps once instead of once per loss term.
    timesteps = ts.cpu().numpy()
    num_timesteps = diffusion.num_timesteps
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Bucket each sample by which quarter of the timestep range it is in.
        for sub_t, sub_loss in zip(timesteps, values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
no test coverage detected