step(self)
```python
        # ... (end of the preceding method)
        self.optimizer.zero_grad()

    def step(self):
        """
        Executes the parameter update step. This includes all-reducing the gradients, clipping them,
        and updating the parameters. If the update succeeds, it also steps the learning rate scheduler
        and the beta2 scheduler, if they exist.

        Returns:
            success (bool): Whether the parameter update was successful.
            grad_norm (float): The norm of the gradient after clipping.
        """
        # Synchronize gradients across ranks before clipping and updating.
        self._all_reduce_gradients()
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

        success, grad_norm = self.optimizer.step()

        # Advance the schedulers only if the optimizer actually applied an update.
        if success and self._lr_scheduler is not None:
            self._lr_scheduler.step()

        if success and self._beta2_scheduler is not None:
            self._beta2_scheduler.step()

        return success, grad_norm

    def train(self):
        """Sets the model to training mode."""
        # ... (body not shown in this excerpt)
```
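The ordering inside `step()` (all-reduce, clip, update, then step schedulers only on success) can be illustrated with a minimal, framework-free sketch. Everything here is hypothetical and not part of the original codebase: `ToyEngine`, `all_reduce_mean`, `clip_by_global_norm` and `CountingScheduler` are illustrative names, gradients are flat lists of floats, "ranks" are simulated as a list of per-rank gradient vectors, and a non-finite gradient norm (as with an fp16 overflow) is assumed to be what makes an update fail. Note that this sketch returns the norm measured before scaling, which is what `torch.nn.utils.clip_grad_norm_` returns.

```python
import math
from typing import List, Optional, Tuple


def all_reduce_mean(per_rank_grads: List[List[float]]) -> List[float]:
    """Simulated all-reduce: average each gradient entry across all ranks."""
    n = len(per_rank_grads)
    return [sum(vals) / n for vals in zip(*per_rank_grads)]


def clip_by_global_norm(grads: List[float], max_norm: float) -> Tuple[List[float], float]:
    """Scale grads so their L2 norm is at most max_norm; return (grads, pre-clip norm)."""
    total = math.sqrt(sum(g * g for g in grads))
    if math.isfinite(total) and total > max_norm:
        scale = max_norm / (total + 1e-6)  # small epsilon avoids division by zero
        grads = [g * scale for g in grads]
    return grads, total


class CountingScheduler:
    """Hypothetical stand-in for an LR or beta2 scheduler; it only counts steps."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


class ToyEngine:
    """Hypothetical engine mirroring the control flow of step() above."""

    def __init__(
        self,
        params: List[float],
        lr: float,
        max_norm: float,
        lr_scheduler: Optional[CountingScheduler] = None,
        beta2_scheduler: Optional[CountingScheduler] = None,
    ) -> None:
        self.params = list(params)
        self.lr = lr
        self.max_norm = max_norm
        self._lr_scheduler = lr_scheduler
        self._beta2_scheduler = beta2_scheduler

    def step(self, per_rank_grads: List[List[float]]) -> Tuple[bool, float]:
        grads = all_reduce_mean(per_rank_grads)
        grads, grad_norm = clip_by_global_norm(grads, self.max_norm)

        # Assumption: an inf/NaN norm means the update is skipped.
        success = math.isfinite(grad_norm)
        if success:
            # Plain SGD update on the synchronized, clipped gradients.
            self.params = [p - self.lr * g for p, g in zip(self.params, grads)]

        # Schedulers advance only when an update was actually applied.
        if success and self._lr_scheduler is not None:
            self._lr_scheduler.step()
        if success and self._beta2_scheduler is not None:
            self._beta2_scheduler.step()

        return success, grad_norm
```

For example, with two simulated ranks holding `[3.0, 0.0]` and `[3.0, 8.0]`, the averaged gradient is `[3.0, 4.0]` with norm 5.0; with `max_norm=1.0` it is scaled to roughly `[0.6, 0.8]` and both schedulers advance once. A step whose gradients contain `inf` leaves the parameters and both schedulers untouched, which is why the original method checks `success` before stepping the schedulers.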