hub / github.com/mosaicml/composer / ClosureGradScaler

Class ClosureGradScaler

composer/trainer/_scaler.py:16–146

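ClosureGradScaler allows for gradient scaling when training with closures. We use closures with optimizers (see `here <https://pytorch.org/docs/stable/optim.html>`__) during training in order to support certain algorithms like :class:`~composer.algorithms.SAM`. This class allows us to perform gradient scaling (see `here <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__) along with the use of closures during training.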
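To make "stepping an optimizer with a closure" concrete, here is a minimal pure-Python sketch. `ToyClosureSGD` and the dictionary-based parameter are hypothetical stand-ins invented for illustration; they are not Composer or PyTorch code. The point is only that the optimizer, rather than the training loop, decides when the loss and gradient are computed.

```python
from typing import Callable, Dict


class ToyClosureSGD:
    """Hypothetical 1-D optimizer, for illustration only."""

    def __init__(self, param: Dict[str, float], lr: float) -> None:
        self.param = param  # holds 'value' and 'grad'
        self.lr = lr

    def step(self, closure: Callable[[], float]) -> float:
        # The optimizer invokes the closure to (re)compute loss and gradient,
        # then applies a plain gradient-descent update.
        loss = closure()
        self.param['value'] -= self.lr * self.param['grad']
        return loss


param = {'value': 3.0, 'grad': 0.0}
optimizer = ToyClosureSGD(param, lr=0.25)


def closure() -> float:
    # Forward and backward pass for loss = value ** 2.
    param['grad'] = 2.0 * param['value']
    return param['value'] ** 2


loss = optimizer.step(closure)
```

An optimizer such as SAM may call the closure more than once per step, which is why a gradient scaler that assumes a single backward pass per step needs extra handling.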

Source from the content-addressed store, hash-verified

# Names used by this excerpt. The listing starts below the file's import
# section, so these import lines are an assumption about where the names come from.
from torch.cuda.amp.grad_scaler import GradScaler, OptState
from torch.optim import Optimizer


class ClosureGradScaler(GradScaler):
    """ClosureGradScaler allows for gradient scaling when training with closures.

    We use closures with optimizers (see `here <https://pytorch.org/docs/stable/optim.html>`__)
    during training in order to support certain algorithms like
    :class:`~composer.algorithms.SAM`. This class allows us to perform gradient
    scaling (see `here <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__)
    along with the use of closures during training.

    Args:
        ddp_reduce_scalar_and (Callable[[bool], bool]): A function that performs a
            ddp reduction with an `and` operation. Used to determine whether
            or not to continue computing an optimizer's `step` based on the presence
            of `inf/nan` in the gradients.
        ddp_reduce_tensor_sum (Callable[[Tensor], Tensor]): A function that performs
            a ddp reduction across tensors with a `sum` operation. Used to aggregate
            `inf/nan` information stored in tensors across devices.
    """

    def _force_scaler_ready(self, optimizer: Optimizer):
        # Reset the per-optimizer stage so the scaler accepts another unscale.
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        optimizer_state['stage'] = OptState.READY

    def _empty_all_grads(self, optimizer):
        # Drop stale gradients before the closure recomputes them.
        for group in optimizer.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    param.grad = None

    def _unscale_grads_and_continue(self, optimizer: Optimizer):
        # Returns True when the optimizer step should proceed, i.e. no inf/nan was found.
        if not self._enabled:
            return True
        self._check_scale_growth_tracker('step')
        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state['stage'] is OptState.STEPPED:
            raise RuntimeError('step() has already been called since the last update().')

        if optimizer_state['stage'] is OptState.READY:
            self.unscale_(optimizer)
        # Sum the per-device inf/nan flags; any nonzero value means an overflow occurred.
        inf_detected = sum(v.item() for v in optimizer_state['found_inf_per_device'].values())
        return not inf_detected

    def step(self, optimizer: Optimizer, *args, **kwargs):
        """Step the optimizer with amp.

        Always called before the optimizer step. Checks if the optimizer can handle AMP closures (currently only
        Composer's SAM optimizer). If so, it passes an AMP-modified closure to the optimizer.
        """
        closure = kwargs['closure']

        def _amp_closure(**kwargs):
            self._force_scaler_ready(optimizer)
            self._empty_all_grads(optimizer)

            retval: float = closure(**kwargs)

            should_continue = self._unscale_grads_and_continue(optimizer)
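The decision made by `_unscale_grads_and_continue` can be sketched without PyTorch. The following is a simplified illustration under stated assumptions: the helper names `unscale_and_check` and `should_continue`, the list-of-floats gradients, and the device keys are all hypothetical, and the real class delegates unscaling to `GradScaler.unscale_` on tensors.

```python
import math
from typing import Dict, List


def unscale_and_check(grads: List[float], scale: float) -> float:
    """Divide gradients by the loss scale in place; return 1.0 if any is inf/nan."""
    found_inf = 0.0
    for i, g in enumerate(grads):
        grads[i] = g / scale
        if not math.isfinite(grads[i]):
            found_inf = 1.0
    return found_inf


def should_continue(found_inf_per_device: Dict[str, float]) -> bool:
    # Same shape as the listing: sum the per-device flags, proceed only if the sum is zero.
    return not sum(found_inf_per_device.values())


scale = 1024.0
healthy = [2048.0, -512.0]
overflowed = [float('inf'), 1024.0]

flags_ok = {'device0': unscale_and_check(healthy, scale)}
flags_bad = {'device0': 0.0, 'device1': unscale_and_check(overflowed, scale)}

proceed_ok = should_continue(flags_ok)
proceed_bad = should_continue(flags_bad)
```

With healthy gradients the step proceeds; a single overflowed value on any device is enough to skip it.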

Callers 3

__init__ · Method · 0.90
fit · Method · 0.90
apply · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected