Repository: github.com/adobe-research/custom-diffusion

Method training_step

src/model.py, lines 275–296
(self, batch, batch_idx)


```python
def training_step(self, batch, batch_idx):
    # A list batch is expected to hold two batches (e.g. one from each of
    # two training dataloaders); otherwise a single batch is used as-is.
    if isinstance(batch, list):
        train_batch = batch[0]
        train2_batch = batch[1]
        loss_train, loss_dict = self.shared_step(train_batch)
        # Only the first batch's loss_dict is logged; the second batch
        # contributes to the optimized loss but its dict is discarded.
        loss_train2, _ = self.shared_step(train2_batch)
        loss = loss_train + loss_train2
    else:
        train_batch = batch
        loss, loss_dict = self.shared_step(train_batch)

    # Log each loss term both per step and aggregated over the epoch.
    self.log_dict(loss_dict, prog_bar=True,
                  logger=True, on_step=True, on_epoch=True)

    self.log("global_step", self.global_step,
             prog_bar=True, logger=True, on_step=True, on_epoch=False)

    # With a scheduler enabled, also log the current learning rate
    # of the first optimizer parameter group.
    if self.use_scheduler:
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    return loss

def shared_step(self, batch, **kwargs):
    x, c, mask = self.get_input_withmask(batch, **kwargs)
    # ... (remainder of shared_step not shown on this page)
```
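To make the control flow of `training_step` visible without a GPU or PyTorch Lightning, here is a minimal sketch that mirrors the method in a hypothetical `ToyModule`. Its `shared_step`, `log_dict` and `log` are stubs (assumptions, not the repository's code), the `"train/loss"` key is invented for illustration, and the learning-rate branch is omitted because it needs a real optimizer. It shows that with a two-element batch list the returned loss is the sum of both losses, while only the first batch's `loss_dict` reaches the logs.

```python
from typing import Any, Dict, List, Tuple, Union

Batch = Dict[str, Any]


class ToyModule:
    """Hypothetical stand-in mirroring training_step's control flow.

    shared_step, log_dict and log are stubs; the real module computes a
    diffusion loss and forwards logs to PyTorch Lightning.
    """

    def __init__(self) -> None:
        self.global_step = 0
        self.logged: Dict[str, Any] = {}

    def shared_step(self, batch: Batch) -> Tuple[float, Dict[str, float]]:
        # Stub: the "loss" is read straight from the batch.
        loss = float(batch["loss"])
        return loss, {"train/loss": loss}  # hypothetical log key

    def log_dict(self, values: Dict[str, float], **kwargs: Any) -> None:
        # Stub: record the values instead of sending them to a logger.
        self.logged.update(values)

    def log(self, name: str, value: Any, **kwargs: Any) -> None:
        self.logged[name] = value

    def training_step(self, batch: Union[Batch, List[Batch]], batch_idx: int) -> float:
        if isinstance(batch, list):
            # Sum both losses; keep only the first batch's loss_dict.
            loss_train, loss_dict = self.shared_step(batch[0])
            loss_train2, _ = self.shared_step(batch[1])
            loss = loss_train + loss_train2
        else:
            loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss


m = ToyModule()
print(m.training_step([{"loss": 0.5}, {"loss": 0.25}], 0))  # 0.75
print(m.logged["train/loss"])  # 0.5
```

Swapping the two batches in the list would change which `loss_dict` appears in the logs, but not the returned loss.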

Callers

Nothing in the repository calls this method directly; PyTorch Lightning's training loop invokes it once per training batch.
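Because the trainer, not user code, invokes `training_step`, the list-shaped batch has to be assembled on the caller side. The sketch below is an assumption about that mechanism, not Lightning's implementation: a hypothetical `run_epoch` driver passes batches through unchanged when there is one loader, and zips batches from several loaders into a list otherwise. Lightning's real rules for combining loaders (for example, how it handles loaders of different lengths) differ; this sketch simply stops at the shortest loader because it uses `zip`.

```python
from typing import Any, Dict, Iterable, List, Union

Batch = Dict[str, float]


class SummingModule:
    """Hypothetical module whose training_step accepts one batch or a list of two."""

    def training_step(self, batch: Union[Batch, List[Batch]], batch_idx: int) -> float:
        if isinstance(batch, list):
            return batch[0]["loss"] + batch[1]["loss"]
        return batch["loss"]


def run_epoch(module: SummingModule, loaders: List[Iterable[Batch]]) -> List[float]:
    """Minimal stand-in for a trainer's fit loop (not Lightning's code).

    One loader: each batch is passed as-is. Several loaders: batches are
    zipped and handed over as a list, matching the isinstance(batch, list)
    branch in training_step. zip stops at the shortest loader.
    """
    if len(loaders) == 1:
        batches: Iterable[Any] = loaders[0]
    else:
        batches = (list(group) for group in zip(*loaders))
    losses = []
    for batch_idx, batch in enumerate(batches):
        losses.append(module.training_step(batch, batch_idx))
    return losses


m = SummingModule()
train = [{"loss": 1.0}, {"loss": 2.0}]
train2 = [{"loss": 0.5}, {"loss": 0.5}, {"loss": 0.5}]
print(run_epoch(m, [train, train2]))  # [1.5, 2.5]
print(run_epoch(m, [train]))  # [1.0, 2.0]
```

A real trainer also runs the backward pass, optimizer and scheduler steps, and logging around each call; all of that is omitted here.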

Calls (1)

shared_step (method) · 0.95

Tested by

No test coverage detected.