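Backward step. Args: model: Model instance or model name; its backward(loss) method is called when DeepSpeed is enabled. scaler: Optional gradient scaler; when truthy, the loss is passed through scaler.scale() before backward() is called. loss_dict: Dict holding the loss under the "loss" key.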
(self, model, scaler, loss_dict={})
| 692 | loss_dict["weight"] = weight |
| 693 | |
def backward_step(self, model, scaler, loss_dict=None):
    """Run the backward pass for the loss stored in ``loss_dict``.

    Args:
        model: Object whose ``backward(loss)`` method is invoked when
            ``self.use_deepspeed`` is truthy; not used otherwise.
        scaler: Optional gradient scaler. When truthy, the loss is passed
            through ``scaler.scale(...)`` before ``backward()`` is called.
        loss_dict: Mapping that must hold the loss under the ``"loss"`` key.

    Raises:
        KeyError: If ``loss_dict`` is omitted or has no ``"loss"`` entry.
    """
    # ``None`` replaces the former shared mutable ``{}`` default; an omitted
    # argument still ends in the same KeyError on the lookup below.
    if loss_dict is None:
        loss_dict = {}
    loss = loss_dict["loss"]

    if self.use_deepspeed:
        # The loss is handed over undivided on this path.
        # NOTE(review): presumably the engine applies gradient accumulation
        # scaling itself -- confirm against the DeepSpeed configuration.
        model.backward(loss)
        return

    # Average the loss over the accumulation window before backpropagating.
    loss = loss / self.accum_grad
    if scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()
| 712 | |
| 713 | def update_step(self, model, optim, scheduler, scaler, loss_dict=None): |
| 714 | """Update step. |