(self, batch, cond)
| 186 | self.log_step() |
| 187 | |
def forward_backward(self, batch, cond):
    """
    Run the forward and backward passes for ``batch``, one microbatch at a time.

    ``zero_grad`` is called once on ``self.model_params`` up front; after
    that, ``backward()`` is invoked once per slice of ``self.microbatch``
    rows with no further clearing in between.

    :param batch: object sliced along its first axis; must expose
                  ``shape[0]`` and ``.to(device)``.
    :param cond: dict whose values are sliced with the same row window as
                 ``batch`` and forwarded to the loss as ``model_kwargs``.
    """
    zero_grad(self.model_params)

    for start in range(0, batch.shape[0], self.microbatch):
        stop = start + self.microbatch
        rows = slice(start, stop)

        # Move only the current window of the batch (and its conditioning)
        # onto the training device.
        micro = batch[rows].to(dist_util.dev())
        micro_cond = {
            name: value[rows].to(dist_util.dev()) for name, value in cond.items()
        }
        is_final_chunk = stop >= batch.shape[0]

        sampler = self.schedule_sampler
        t, weights = sampler.sample(micro.shape[0], dist_util.dev())

        compute_losses = functools.partial(
            self.diffusion.training_losses,
            self.ddp_model,
            micro,
            t,
            model_kwargs=micro_cond,
        )

        # Only non-final chunks under DDP run inside no_sync().
        # NOTE(review): presumably this defers the gradient all-reduce to
        # the final chunk's backward pass -- confirm against the DDP docs.
        if not is_final_chunk and self.use_ddp:
            with self.ddp_model.no_sync():
                losses = compute_losses()
        else:
            losses = compute_losses()

        # Loss-aware samplers are fed the raw (unweighted) per-sample loss.
        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(
                t, losses["loss"].detach()
            )

        weighted_loss = losses["loss"] * weights
        loss = weighted_loss.mean()

        weighted_terms = {name: term * weights for name, term in losses.items()}
        log_loss_dict(self.diffusion, t, weighted_terms)

        # With fp16 enabled the loss is multiplied by 2 ** lg_loss_scale
        # before backpropagation; otherwise it is backpropagated as-is.
        if self.use_fp16:
            backward_target = loss * (2 ** self.lg_loss_scale)
        else:
            backward_target = loss
        backward_target.backward()
| 227 | |
| 228 | def optimize_fp16(self): |
| 229 | if any(not th.isfinite(p.grad).all() for p in self.model_params): |
no test coverage detected