Repository: github.com/microsoft/Swin-Transformer

Method __call__

utils.py:205–219
(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True)


        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = ampscaler_get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()
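The scale, unscale_, step, update sequence in __call__ is PyTorch's dynamic loss scaling. unscale_ is called explicitly so the gradient norm is measured on true (unscaled) gradients; GradScaler.step then sees that unscaling already happened, skips optimizer.step() if any gradient is inf or NaN, and update() shrinks or grows the scale. As a rough illustration only, here is a hypothetical pure-Python model of that behavior. ToyLossScaler, its methods, and the plain-SGD update are invented for this sketch; the default values follow the documented torch.cuda.amp.GradScaler constructor defaults. It is not PyTorch's implementation, which works on CUDA tensors per optimizer.

```python
import math


class ToyLossScaler:
    """Hypothetical pure-Python model of dynamic loss scaling (not PyTorch code).

    Defaults follow the documented torch.cuda.amp.GradScaler constructor:
    init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000.
    """

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive steps without inf/NaN gradients
        self._found_inf = False

    def unscale(self, scaled_grads):
        # Divide the scaled gradients by the current scale; flag any inf/NaN.
        grads = [g / self.scale for g in scaled_grads]
        self._found_inf = not all(math.isfinite(g) for g in grads)
        return grads

    def step(self, params, grads, lr):
        # Skip the parameter update when the gradients overflowed.
        if self._found_inf:
            return params
        return [p - lr * g for p, g in zip(params, grads)]  # plain SGD, for illustration

    def update(self):
        # Shrink by backoff_factor after an overflow; grow by growth_factor
        # after growth_interval consecutive clean steps.
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


scaler = ToyLossScaler(growth_interval=2)
params = [1.0]

# Overflowing iteration: the step is skipped and the scale backs off.
grads = scaler.unscale([float("inf")])
params = scaler.step(params, grads, lr=0.5)
scaler.update()
print(params, scaler.scale)  # [1.0] 32768.0

# Two clean iterations: parameters update and the scale grows back.
for _ in range(2):
    grads = scaler.unscale([0.5 * scaler.scale])
    params = scaler.step(params, grads, lr=0.5)
    scaler.update()
print(params, scaler.scale)  # [0.5] 65536.0
```

A small growth_interval is used only so the growth step is visible after two iterations; the real default waits 2000 clean steps before growing the scale.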

Callers

nothing calls this directly
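Python runs __call__ when a scaler instance is called like a function, as in scaler(loss, optimizer, ...), so call sites need not name the method, which may be why none are listed. A common use of the update_grad flag is gradient accumulation: backward runs on every iteration, but unscaling, stepping, and updating happen only every N iterations, and the method returns None in between. The sketch below uses a hypothetical RecordingScaler (not part of Swin-Transformer) that keeps only that control flow and counts calls instead of touching tensors; the schedule (idx + 1) % accumulation_steps == 0 is an assumption chosen for illustration.

```python
class RecordingScaler:
    """Hypothetical stand-in that mirrors the control flow of __call__ above.

    It records calls instead of doing real backward/step work, so the
    gradient-accumulation pattern can be shown without PyTorch or a GPU.
    """

    def __init__(self):
        self.backward_calls = 0
        self.optimizer_steps = 0

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        self.backward_calls += 1   # stands in for scale(loss).backward()
        if not update_grad:
            return None            # accumulate only: no step, no norm
        self.optimizer_steps += 1  # stands in for unscale_/step/update
        return 1.0                 # placeholder for the gradient norm


def train(num_iters, accumulation_steps):
    scaler = RecordingScaler()
    norms = []
    for idx in range(num_iters):
        # Placeholder loss; real loops commonly divide the loss by
        # accumulation_steps and zero gradients after each optimizer step.
        loss = 0.0
        # Calling the instance invokes RecordingScaler.__call__ implicitly.
        norm = scaler(loss, optimizer=None,
                      update_grad=(idx + 1) % accumulation_steps == 0)
        norms.append(norm)
    return scaler, norms


scaler, norms = train(num_iters=6, accumulation_steps=3)
print(scaler.backward_calls, scaler.optimizer_steps)  # 6 2
print(norms)  # [None, None, 1.0, None, None, 1.0]
```

Because the return value is None on accumulation-only iterations, a caller that logs the norm has to handle that case.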

Calls (2)

ampscaler_get_grad_norm · Function · 0.85
backward · Method · 0.45
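Both branches of __call__ return a gradient norm: torch.nn.utils.clip_grad_norm_ rescales the gradients in place and returns their total norm measured before clipping, while ampscaler_get_grad_norm only measures it. The following pure-Python sketch shows that arithmetic, assuming the L2 norm (norm_type=2) and PyTorch's clipping rule coef = min(max_norm / (total_norm + 1e-6), 1). The helper names global_grad_norm and clip_by_global_norm are hypothetical, and lists of floats stand in for the .grad tensors.

```python
import math


def global_grad_norm(grads):
    """Total L2 norm over all gradients, treated as one concatenated vector.

    Equivalent to the L2 norm of the per-parameter L2 norms. Each entry of
    `grads` is a list of floats (one list per parameter tensor).
    """
    return math.sqrt(sum(g * g for grad in grads for g in grad))


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Hypothetical pure-Python analogue of torch.nn.utils.clip_grad_norm_.

    Returns (clipped_grads, total_norm); total_norm is measured before clipping.
    """
    total_norm = global_grad_norm(grads)
    coef = min(max_norm / (total_norm + eps), 1.0)  # never scale gradients up
    clipped = [[g * coef for g in grad] for grad in grads]
    return clipped, total_norm


grads = [[3.0], [4.0]]  # two "parameters" with scalar gradients
print(global_grad_norm(grads))  # 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)  # 5.0 (the pre-clipping norm is what gets returned)
print(global_grad_norm(clipped))  # slightly below 1.0
```

Clipping by the global norm scales every gradient by the same factor, so the update direction is preserved and only its length is capped.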

Tested by

no test coverage detected