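ClosureGradScaler allows for gradient scaling during training with closures. We use closures with optimizers (see `here <https://pytorch.org/docs/stable/optim.html>`__) during training in order to support certain algorithms like :class:`~composer.algorithms.SAM`. This class allows us to perform gradient scaling (see `here <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__) along with the use of closures during training.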
| 14 | |
| 15 | |
| 16 | class ClosureGradScaler(GradScaler): |
| 17 | """ClosureGradScaler allows for gradient scaling during with closures. |
| 18 | |
| 19 | We use closures with optimizers (see `here <https://pytorch.org/docs/stable/optim.html>`__) |
| 20 | during training in order to support certain algorithms like |
| 21 | :class:`~composer.algorithms.SAM`. This class allows us to perform gradient |
| 22 | scaling (see `here <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__) |
| 23 | along with the use of closures during training. |
| 24 | |
| 25 | Args: |
| 26 | ddp_reduce_scalar_and (Callable[[bool], bool]): A function that performs a |
| 27 | ddp reduction with an `and` operation. Used to determine whether |
| 28 | or not to continue computing an optimizer's `step` based on the presence |
| 29 | of `inf/nan` in the gradients. |
| 30 | ddp_reduce_tensor_sum (Callable[[Tensor], Tensor]): A function that performs |
| 31 | a ddp reduction across tensors with a `sum` operation. Used to aggregate |
| 32 | `inf/nan` information stored in tensors across devices. |
| 33 | """ |
| 34 | |
def _force_scaler_ready(self, optimizer: Optimizer):
    """Reset ``optimizer``'s per-optimizer scaler stage to ``OptState.READY``.

    Only the ``'stage'`` entry of the state stored under ``id(optimizer)`` is
    overwritten; every other entry of that state is left as it was. The AMP
    closure built in ``step`` calls this before running the user closure, so
    that the subsequent ``_unscale_grads_and_continue`` sees a READY stage and
    unscales the freshly computed gradients.

    Args:
        optimizer (Optimizer): The optimizer whose scaler state is reset.
    """
    self._per_optimizer_states[id(optimizer)]['stage'] = OptState.READY
| 38 | |
def _empty_all_grads(self, optimizer):
    """Drop every gradient currently attached to ``optimizer``'s parameters.

    Walks all parameters of all ``optimizer.param_groups`` and sets each
    populated ``.grad`` to ``None``; parameters without a gradient are
    skipped. The AMP closure built in ``step`` calls this before invoking the
    user closure, so each closure run starts from empty gradients.

    Args:
        optimizer: The optimizer whose parameters' gradients are cleared.
    """
    populated = (
        param
        for group in optimizer.param_groups
        for param in group['params']
        if param.grad is not None
    )
    for param in populated:
        param.grad = None
| 44 | |
def _unscale_grads_and_continue(self, optimizer: Optimizer):
    """Unscale ``optimizer``'s gradients if needed and report whether it may step.

    When the scaler is disabled this returns ``True`` without touching any
    state. Otherwise the optimizer's gradients are unscaled via ``unscale_``,
    but only if its stage is ``OptState.READY``; an optimizer in any other
    non-STEPPED stage is not unscaled again. The per-device found-inf flags
    stored for the optimizer are then summed.

    Args:
        optimizer (Optimizer): The optimizer whose gradients are inspected.

    Returns:
        bool: ``True`` if the summed found-inf flags are zero (or the scaler is
        disabled), ``False`` otherwise.

    Raises:
        RuntimeError: If the optimizer's stage is ``OptState.STEPPED``, i.e.
            ``step()`` was already called since the last ``update()``.
    """
    if not self._enabled:
        return True

    self._check_scale_growth_tracker('step')
    state = self._per_optimizer_states[id(optimizer)]
    stage = state['stage']

    if stage is OptState.STEPPED:
        raise RuntimeError('step() has already been called since the last update().')
    if stage is OptState.READY:
        self.unscale_(optimizer)

    # Read the flags only after the (possible) unscale above; presumably
    # ``unscale_`` is what fills ``found_inf_per_device`` (torch's does).
    found_inf_total = 0
    for found_inf in state['found_inf_per_device'].values():
        found_inf_total += found_inf.item()
    return not found_inf_total
| 58 | |
| 59 | def step(self, optimizer: Optimizer, *args, **kwargs): |
| 60 | """Step the optimizer with amp. |
| 61 | |
| 62 | Always called before the optimizer step. Checks if the optimizer can handle AMP closures (currently only |
| 63 | Composer's SAM optimizer) If so, it passes an AMP-modified closure to the optimizer. |
| 64 | """ |
| 65 | closure = kwargs['closure'] |
| 66 | |
| 67 | def _amp_closure(**kwargs): |
| 68 | self._force_scaler_ready(optimizer) |
| 69 | self._empty_all_grads(optimizer) |
| 70 | |
| 71 | retval: float = closure(**kwargs) |
| 72 | |
| 73 | should_continue = self._unscale_grads_and_continue(optimizer) |