(self, batch, cond)
| 178 | self.log_step() |
| 179 | |
def forward_backward(self, batch, cond):
    """
    Run the forward and backward passes over ``batch`` one microbatch at a time.

    Gradients are zeroed once through ``self.mp_trainer``; the batch is then
    walked in slices of ``self.microbatch`` rows. For every slice the
    diffusion training losses are computed, multiplied by the weights the
    schedule sampler returned, logged, and handed to
    ``self.mp_trainer.backward``.

    :param batch: data whose leading size is ``batch.shape[0]`` and whose
                  slices support ``.to(device)``.
    :param cond: a dict of conditioning values; each value is sliced with the
                 same row range as ``batch`` and forwarded to the loss
                 function as ``model_kwargs``.
    """
    self.mp_trainer.zero_grad()
    total_rows = batch.shape[0]

    for start in range(0, total_rows, self.microbatch):
        stop = start + self.microbatch

        # Move only the current slice (data and conditioning) to the device.
        micro_batch = batch[start:stop].to(dist_util.dev())
        micro_cond = {}
        for key, value in cond.items():
            micro_cond[key] = value[start:stop].to(dist_util.dev())

        is_final = stop >= total_rows
        t, loss_weights = self.schedule_sampler.sample(
            micro_batch.shape[0], dist_util.dev()
        )

        run_losses = functools.partial(
            self.diffusion.training_losses,
            self.ddp_model,
            micro_batch,
            t,
            model_kwargs=micro_cond,
        )

        # Only non-final microbatches under DDP run inside no_sync();
        # presumably this defers gradient synchronisation to the final
        # microbatch -- confirm against the wrapper used for self.ddp_model.
        if not is_final and self.use_ddp:
            with self.ddp_model.no_sync():
                losses = run_losses()
        else:
            losses = run_losses()

        # Loss-aware samplers receive the raw (unweighted) per-sample losses.
        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(
                t, losses["loss"].detach()
            )

        loss = (losses["loss"] * loss_weights).mean()

        weighted_terms = {}
        for name, term in losses.items():
            weighted_terms[name] = term * loss_weights
        log_loss_dict(self.diffusion, t, weighted_terms)

        # backward() is called once per microbatch after a single zero_grad(),
        # so gradients presumably accumulate across microbatches.
        self.mp_trainer.backward(loss)
| 215 | |
| 216 | def _update_ema(self): |
| 217 | for rate, params in zip(self.ema_rate, self.ema_params): |
no test coverage detected