| 298 | |
| 299 | |
class NativeScalerWithGradNormCount:
    """Thin wrapper around ``torch.cuda.amp.GradScaler`` that returns the gradient norm.

    Calling an instance runs the scaled backward pass and, when requested,
    unscales the gradients, measures (and optionally clips) their norm, steps
    the optimizer and updates the scaler.
    """

    # Key name published for this scaler's state; presumably used by callers
    # as the checkpoint entry name -- confirm against the call sites.
    state_dict_key = "amp_scaler"

    def __init__(self, grad_scaler_enabled=True):
        """Build the wrapped scaler, forwarding ``grad_scaler_enabled`` as ``enabled``."""
        self._scaler = torch.cuda.amp.GradScaler(enabled=grad_scaler_enabled)

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True,
    ):
        """Back-propagate the scaled ``loss`` and optionally apply an optimizer step.

        Args:
            loss: Value passed to ``GradScaler.scale`` before ``backward``.
            optimizer: Optimizer handed to ``unscale_`` and ``step``.
            clip_grad: Maximum gradient norm. Clipping happens only when this
                is not ``None`` and greater than ``0.0``.
            parameters: Parameters whose gradients are measured/clipped. Must
                not be ``None`` when clipping is active.
            create_graph: Forwarded to ``backward``.
            update_grad: When false, only the backward pass runs; the optimizer
                and the scaler are left untouched.

        Returns:
            The value returned by ``torch.nn.utils.clip_grad_norm_`` when
            clipping, otherwise the result of ``ampscaler_get_grad_norm``
            (defined elsewhere in this file; presumably the unclipped gradient
            norm), or ``None`` when ``update_grad`` is false.

        Raises:
            AssertionError: If clipping is active and ``parameters`` is ``None``.
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)

        # Backward-only call: no unscale, no step, no scaler update.
        if not update_grad:
            return None

        wants_clipping = clip_grad is not None and clip_grad > 0.0
        if wants_clipping:
            # Checked before unscale_ so a bad call fails without touching
            # the gradients.
            assert parameters is not None

        # Unscale the gradients of the optimizer's assigned params in-place so
        # the norm below is computed on true gradient values.
        self._scaler.unscale_(optimizer)

        if wants_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            grad_norm = ampscaler_get_grad_norm(parameters)

        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

    def state_dict(self):
        """Return the wrapped scaler's ``state_dict()``."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)
| 328 | |
| 329 | |
| 330 | def is_main_process(): |