Zero fp32 and fp16 parameter grads.
(self, set_grads_to_None=False)
| 250 | raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().") |
| 251 | |
def zero_grad(self, set_grads_to_None=False):
    """
    Reset the gradients of every parameter reachable from this wrapper.

    Two collections are visited, in this order: the parameters listed in
    ``self.optimizer.param_groups`` and the parameters held in
    ``self.fp16_groups``.

    Args:
        set_grads_to_None: when true, each parameter's ``.grad`` is replaced
            with ``None``. Otherwise an existing ``.grad`` is detached and
            zeroed in place, and a ``.grad`` that is already ``None`` is
            left untouched.
    """
    def _reset(tensor):
        # Dropping the gradient entirely takes precedence over zeroing it.
        if set_grads_to_None:
            tensor.grad = None
            return
        grad = tensor.grad
        if grad is None:
            return
        grad.detach_()  # same approach as torch.optim.optimizer.zero_grad()
        grad.zero_()

    # Strictly speaking only the model params' .grad needs clearing, since
    # gradients get copied into the FP32 master params. The grads owned by
    # the optimizer are cleared as well, purely as a precaution.
    for param_group in self.optimizer.param_groups:
        for master_param in param_group['params']:
            _reset(master_param)

    # Clear the fp16 gradients owned by the model.
    for half_group in self.fp16_groups:
        for half_param in half_group:
            _reset(half_param)
| 277 | |
| 278 | def _check_overflow(self): |
| 279 | params = [] |
no outgoing calls
no test coverage detected