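Run one forward pass and store its outputs in loss_dict. Args: model: callable model, invoked as model(**batch) and expected to return a (loss, stats, weight) triple. batch: mapping of keyword arguments passed to the model. loss_dict: dict populated in place with the keys "loss", "stats" (with None-valued entries removed) and "weight".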
(self, model, batch, loss_dict={})
| 674 | self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size |
| 675 | |
def forward_step(self, model, batch, loss_dict=None):
    """Run one forward pass and store its outputs in ``loss_dict``.

    Args:
        model: Callable model. It is invoked as ``model(**batch)`` and must
            return a ``(loss, stats, weight)`` triple whose ``stats`` element
            is a mapping.
        batch: Mapping of keyword arguments forwarded to ``model``.
        loss_dict: Optional dict to populate in place. When omitted (or
            ``None``), a fresh dict is created for this call.

    Returns:
        The populated ``loss_dict``. It holds the keys ``"loss"``,
        ``"stats"`` (the model's stats with ``None`` values removed) and
        ``"weight"``.
    """
    # A shared mutable default ({}) would persist across calls. It would keep
    # the previous step's outputs alive while being invisible to the caller,
    # so allocate a new dict per call instead.
    if loss_dict is None:
        loss_dict = {}

    # NOTE(review): maybe_autocast is defined elsewhere; presumably it enables
    # mixed-precision autocast depending on self.dtype / self.use_deepspeed —
    # confirm. Only the model call runs inside this context.
    with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
        retval = model(**batch)

    loss, stats, weight = retval
    # Drop entries the model reported as None.
    stats = {k: v for k, v in stats.items() if v is not None}

    loss_dict["loss"] = loss
    loss_dict["stats"] = stats
    loss_dict["weight"] = weight
    return loss_dict
| 694 | def backward_step(self, model, scaler, loss_dict={}): |
| 695 | """Backward step. |
no test coverage detected