MCPcopy
hub / github.com/microsoft/Swin-Transformer / NativeScalerWithGradNormCount

Class NativeScalerWithGradNormCount

utils.py:199–225  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

197
198
class NativeScalerWithGradNormCount:
    """Mixed-precision loss scaler that also reports the gradient norm.

    Wraps a ``torch.cuda.amp.GradScaler`` and bundles one AMP training step
    into a single call: scaled backward pass, gradient unscaling, optional
    gradient-norm clipping, optimizer step and scale-factor update.
    """

    # NOTE(review): presumably the checkpoint key under which callers store
    # ``state_dict()``; it is not read inside this class -- confirm at the
    # call sites.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Run backward on the scaled loss and, optionally, step the optimizer.

        Args:
            loss: Loss tensor to scale and back-propagate.
            optimizer: Optimizer whose gradients are unscaled and which is
                stepped through the wrapped scaler.
            clip_grad: Maximum gradient norm passed to
                ``torch.nn.utils.clip_grad_norm_``; ``None`` disables clipping.
            parameters: Parameters whose gradient norm is measured (and
                clipped when ``clip_grad`` is set). Required when
                ``clip_grad`` is not ``None``.
            create_graph: Forwarded to ``Tensor.backward``.
            update_grad: When false, only the backward pass runs (gradients
                accumulate) and the optimizer/scaler are left untouched.

        Returns:
            The gradient norm computed after unscaling, or ``None`` when
            ``update_grad`` is false.

        Raises:
            ValueError: If ``clip_grad`` is set, ``update_grad`` is true and
                ``parameters`` is ``None``.
        """
        # Validate up front with a real exception: a bare ``assert`` vanishes
        # under ``python -O``, and failing before ``backward()`` avoids
        # leaving scaled gradients accumulated by a call that cannot finish.
        if update_grad and clip_grad is not None and parameters is None:
            raise ValueError("`parameters` must be provided when `clip_grad` is set")

        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        # Unscale the gradients of the optimizer's assigned params in-place so
        # the norm below is measured on true (unscaled) gradient values.
        self._scaler.unscale_(optimizer)
        if clip_grad is not None:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            # ``ampscaler_get_grad_norm`` is defined elsewhere in this module.
            # NOTE(review): ``parameters`` is forwarded as-is, including
            # ``None``; whether the helper accepts that is not visible here.
            norm = ampscaler_get_grad_norm(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)

Callers 2

main · Function · 0.90
main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected