(self, batch, batch_idx)
| 273 | return out |
| 274 | |
def training_step(self, batch, batch_idx):
    """Compute the training loss for one step and log step metrics.

    A ``list`` batch is treated as two sub-batches, taken from items 0 and 1.
    Each sub-batch goes through ``shared_step`` and the two losses are summed.
    Only the loss dictionary from the first sub-batch is logged; the second
    sub-batch's dictionary is discarded. Any non-list ``batch`` is passed to
    ``shared_step`` unchanged.

    Args:
        batch: a single batch, or a list whose first two items are sub-batches.
        batch_idx: index of the batch within the epoch (not used here).

    Returns:
        The loss from ``shared_step``, summed over both sub-batches when a
        list is given.
    """
    if not isinstance(batch, list):
        total_loss, metrics = self.shared_step(batch)
    else:
        # Index explicitly rather than unpacking: the list may hold more
        # than two entries, and only the first two are used.
        first_part, second_part = batch[0], batch[1]
        first_loss, metrics = self.shared_step(first_part)
        second_loss, _ = self.shared_step(second_part)
        total_loss = first_loss + second_loss

    self.log_dict(metrics, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    # Both values below are logged per step only, never aggregated per epoch.
    per_step_only = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)
    self.log("global_step", self.global_step, **per_step_only)

    if self.use_scheduler:
        # Report the learning rate of the first param group of the optimizer.
        current_lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', current_lr, **per_step_only)

    return total_loss
| 297 | |
| 298 | def shared_step(self, batch, **kwargs): |
| 299 | x, c, mask = self.get_input_withmask(batch, **kwargs) |
nothing calls this directly
no test coverage detected