MCPcopy Index your code
hub / github.com/openai/improved-diffusion / forward_backward

Method forward_backward

improved_diffusion/train_util.py:188–226  ·  view source on GitHub ↗
(self, batch, cond)

Source from the content-addressed store, hash-verified

186 self.log_step()
187
def forward_backward(self, batch, cond):
    """
    Zero the gradients, then run forward/backward passes over ``batch`` one
    microbatch at a time, accumulating gradients into ``self.model_params``.

    No optimizer step is taken here; this only populates gradients.

    :param batch: a tensor sliced along its first dimension into chunks of
        ``self.microbatch`` rows; each chunk is moved to ``dist_util.dev()``.
    :param cond: a dict of tensors, each sliced along its first dimension in
        step with ``batch``, moved to ``dist_util.dev()``, and passed to the
        loss as ``model_kwargs``.

    NOTE(review): gradients are zeroed once before the loop and every
    microbatch calls ``backward()`` on the *mean* of its own weighted loss.
    The accumulated gradient is therefore a sum over microbatches, not an
    average, and a shorter trailing microbatch counts as much as a full one.
    Confirm this is intended before changing ``self.microbatch``.
    """

    def run_losses(x, timesteps, kwargs):
        # Dict of per-example loss terms from the diffusion object, computed
        # through ``self.ddp_model``.
        return self.diffusion.training_losses(
            self.ddp_model, x, timesteps, model_kwargs=kwargs
        )

    zero_grad(self.model_params)
    total = batch.shape[0]
    chunk = self.microbatch
    for start in range(0, total, chunk):
        stop = start + chunk
        micro = batch[start:stop].to(dist_util.dev())
        micro_cond = {
            name: tensor[start:stop].to(dist_util.dev())
            for name, tensor in cond.items()
        }
        is_final_chunk = stop >= total
        t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

        if not is_final_chunk and self.use_ddp:
            # Assuming ``self.ddp_model`` is a torch DistributedDataParallel
            # wrapper: running the forward pass inside ``no_sync()`` skips the
            # cross-rank gradient all-reduce for this microbatch's backward,
            # so only the final microbatch synchronizes.
            with self.ddp_model.no_sync():
                losses = run_losses(micro, t, micro_cond)
        else:
            losses = run_losses(micro, t, micro_cond)

        if isinstance(self.schedule_sampler, LossAwareSampler):
            # Report the raw per-example losses back to the sampler (presumably
            # so it can reweight timesteps); detach keeps them out of autograd.
            self.schedule_sampler.update_with_local_losses(
                t, losses["loss"].detach()
            )

        loss = (losses["loss"] * weights).mean()
        log_loss_dict(
            self.diffusion,
            t,
            {name: term * weights for name, term in losses.items()},
        )

        if not self.use_fp16:
            loss.backward()
        else:
            # fp16 loss scaling: backprop a loss multiplied by 2**lg_loss_scale
            # (the matching unscale presumably happens in the optimize step).
            scaled_loss = loss * (2 ** self.lg_loss_scale)
            scaled_loss.backward()
227
228 def optimize_fp16(self):
229 if any(not th.isfinite(p.grad).all() for p in self.model_params):

Callers 1

run_step · Method · 0.95

Calls 5

zero_grad · Function · 0.85
log_loss_dict · Function · 0.85
sample · Method · 0.80
backward · Method · 0.80

Tested by

no test coverage detected