(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True)
| 203 | self._scaler = torch.cuda.amp.GradScaler() |
| 204 | |
def __call__(
    self,
    loss,
    optimizer,
    clip_grad=None,
    parameters=None,
    create_graph=False,
    update_grad=True,
):
    """Backpropagate the scaled loss and optionally step the optimizer.

    The loss is scaled by the wrapped gradient scaler and backpropagated on
    every call. Gradients are unscaled, the optimizer is stepped and the
    scaler is updated only when ``update_grad`` is true.

    Args:
        loss: Loss value to scale and backpropagate.
        optimizer: Optimizer whose gradients are unscaled and which is
            stepped through the scaler.
        clip_grad: Maximum gradient norm. When ``None`` no clipping is
            done and the norm is obtained from ``ampscaler_get_grad_norm``.
        parameters: Parameters whose gradients are clipped or measured.
            Must not be ``None`` when ``clip_grad`` is given.
        create_graph: Forwarded to ``backward``.
        update_grad: When false, only the backward pass runs (no unscale,
            no optimizer step, no scaler update).

    Returns:
        ``None`` when ``update_grad`` is false; otherwise the value
        returned by ``torch.nn.utils.clip_grad_norm_`` (clipping enabled)
        or by ``ampscaler_get_grad_norm`` (clipping disabled).
    """
    scaler = self._scaler
    scaler.scale(loss).backward(create_graph=create_graph)

    # Backward-only call: leave the gradients in place and do nothing else.
    if not update_grad:
        return None

    if clip_grad is None:
        scaler.unscale_(optimizer)
        grad_norm = ampscaler_get_grad_norm(parameters)
    else:
        assert parameters is not None
        # Unscale the gradients of the optimizer's assigned params in-place
        # before clipping them.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)

    scaler.step(optimizer)
    scaler.update()
    return grad_norm
| 220 | |
def state_dict(self):
    """Return the state dictionary of the wrapped gradient scaler."""
    scaler_state = self._scaler.state_dict()
    return scaler_state
nothing calls this directly
no test coverage detected