MCPcopy Index your code
hub / github.com/openai/guided-diffusion / forward_backward

Method forward_backward

guided_diffusion/train_util.py:180–214  ·  view source on GitHub ↗
(self, batch, cond)

Source from the content-addressed store, hash-verified

178 self.log_step()
179
def forward_backward(self, batch, cond):
    """
    Accumulate gradients for one batch, one microbatch at a time.

    Gradients are zeroed once up front through ``self.mp_trainer``. Each
    slice of ``self.microbatch`` items is then moved to ``dist_util.dev()``,
    evaluated with ``self.diffusion.training_losses`` and backpropagated
    through ``self.mp_trainer.backward``. No optimizer step is taken here.

    :param batch: object sliceable along its first axis and exposing
                  ``shape[0]`` (presumably a tensor of training examples;
                  confirm against the caller).
    :param cond: a dict whose values are sliced in lockstep with ``batch``
                 and passed to ``training_losses`` as ``model_kwargs``.
    """
    self.mp_trainer.zero_grad()
    for start in range(0, batch.shape[0], self.microbatch):
        stop = start + self.microbatch
        micro_batch = batch[start:stop].to(dist_util.dev())
        micro_kwargs = {}
        for name, value in cond.items():
            micro_kwargs[name] = value[start:stop].to(dist_util.dev())
        is_final = stop >= batch.shape[0]
        timesteps, weights = self.schedule_sampler.sample(
            micro_batch.shape[0], dist_util.dev()
        )

        # Bind the arguments now so the same call can run either inside or
        # outside the no_sync() context below.
        run_losses = functools.partial(
            self.diffusion.training_losses,
            self.ddp_model,
            micro_batch,
            timesteps,
            model_kwargs=micro_kwargs,
        )

        if not is_final and self.use_ddp:
            # Every microbatch except the last runs its forward pass under
            # no_sync(). NOTE(review): presumably ddp_model is a
            # DistributedDataParallel wrapper, so this defers the gradient
            # all-reduce to the final microbatch -- confirm.
            with self.ddp_model.no_sync():
                loss_terms = run_losses()
        else:
            loss_terms = run_losses()

        if isinstance(self.schedule_sampler, LossAwareSampler):
            # Loss-aware samplers receive the raw (unweighted) losses,
            # detached from the autograd graph.
            self.schedule_sampler.update_with_local_losses(
                timesteps, loss_terms["loss"].detach()
            )

        weighted_loss = (loss_terms["loss"] * weights).mean()
        weighted_terms = {
            name: term * weights for name, term in loss_terms.items()
        }
        log_loss_dict(self.diffusion, timesteps, weighted_terms)
        self.mp_trainer.backward(weighted_loss)
215
216 def _update_ema(self):
217 for rate, params in zip(self.ema_rate, self.ema_params):

Callers 1

run_stepMethod · 0.95

Calls 5

log_loss_dictFunction · 0.85
zero_gradMethod · 0.80
sampleMethod · 0.80
backwardMethod · 0.45

Tested by

no test coverage detected