MCPcopy
hub / github.com/microsoft/Cream / NativeScalerWithGradNormCount

Class NativeScalerWithGradNormCount

TinyViT/utils.py:300–327  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

298
299
class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler`` and report the gradient norm.

    Calling an instance runs the scaled backward pass and, when
    ``update_grad`` is true, unscales the gradients, measures (and
    optionally clips) their norm, steps the optimizer and updates the
    scaler.
    """

    # Key under which callers are expected to store this object's state
    # in a checkpoint (the class itself never reads it).
    state_dict_key = "amp_scaler"

    def __init__(self, grad_scaler_enabled=True):
        """Create the underlying scaler.

        Args:
            grad_scaler_enabled: Forwarded as ``enabled`` to
                ``torch.cuda.amp.GradScaler``.
        """
        self._scaler = torch.cuda.amp.GradScaler(enabled=grad_scaler_enabled)

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        """Back-propagate ``loss`` and optionally apply an optimizer step.

        Args:
            loss: Loss tensor to scale and back-propagate.
            optimizer: Optimizer whose gradients are unscaled and stepped.
            clip_grad: Maximum gradient norm. Clipping is applied only when
                this is not ``None`` and greater than ``0.0``.
            parameters: Parameters whose gradient norm is computed/clipped.
                Required when clipping is active.
            create_graph: Forwarded to ``Tensor.backward``.
            update_grad: When false, only the backward pass runs (gradients
                accumulate) and no optimizer step is taken.

        Returns:
            The gradient norm computed after unscaling, or ``None`` when
            ``update_grad`` is false.

        Raises:
            ValueError: If clipping is requested but ``parameters`` is
                ``None``.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        clipping = clip_grad is not None and clip_grad > 0.0
        if clipping and parameters is None:
            # Explicit raise instead of ``assert``: asserts are removed
            # under ``python -O`` and the failure would surface later as an
            # unrelated error inside ``clip_grad_norm_``.
            raise ValueError(
                "parameters must be provided when clip_grad is enabled"
            )

        # Unscale the gradients of the optimizer's assigned params in-place
        # so the norm below is measured on true gradient magnitudes.
        self._scaler.unscale_(optimizer)
        if clipping:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            # NOTE(review): ``ampscaler_get_grad_norm`` is not defined in
            # this class; presumably a module-level helper — confirm how it
            # treats ``parameters=None``.
            norm = ampscaler_get_grad_norm(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the underlying scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the underlying scaler."""
        self._scaler.load_state_dict(state_dict)
328
329
330def is_main_process():

Callers 2

mainFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected