| 197 | |
| 198 | |
class NativeScalerWithGradNormCount:
    """Callable wrapper around ``torch.cuda.amp.GradScaler``.

    Calling an instance runs the scaled backward pass and, optionally, the
    optimizer step, returning the gradient norm computed along the way.
    """

    # Key under which callers are expected to store this object's state.
    # NOTE(review): not used inside this class -- confirm against the
    # checkpointing code.
    state_dict_key = "amp_scaler"

    def __init__(self):
        """Create the underlying ``GradScaler`` with its default settings."""
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Back-propagate the scaled ``loss`` and optionally step ``optimizer``.

        Args:
            loss: Value to scale and call ``backward`` on.
            optimizer: Optimizer whose gradients are unscaled and stepped.
            clip_grad: Max norm handed to ``clip_grad_norm_``; ``None``
                disables clipping.
            parameters: Parameters whose gradient norm is measured (and
                clipped when ``clip_grad`` is given). Must not be ``None``
                when ``clip_grad`` is set.
            create_graph: Forwarded to ``backward``.
            update_grad: When false, only the backward pass runs; there is no
                unscale, optimizer step, or scaler update.

        Returns:
            The value returned by ``clip_grad_norm_`` (clipping enabled) or by
            ``ampscaler_get_grad_norm`` (clipping disabled); ``None`` when
            ``update_grad`` is false.
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)

        if not update_grad:
            return None

        should_clip = clip_grad is not None
        if should_clip:
            # Checked before unscaling so a bad call leaves the scaler untouched.
            assert parameters is not None

        # Unscale the gradients of the optimizer's assigned params in-place so
        # the norm below is measured on unscaled gradients.
        self._scaler.unscale_(optimizer)
        if should_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            grad_norm = ampscaler_get_grad_norm(parameters)

        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)