(model_params)
| 131 | |
| 132 | |
def zero_grad(model_params, *, set_to_none=False):
    """Reset the ``.grad`` of every parameter in ``model_params``.

    Adapted from ``torch.optim.Optimizer.zero_grad`` so gradients can be
    cleared without constructing an optimizer.

    Args:
        model_params: iterable of tensors (e.g. ``model.parameters()``).
            Entries whose ``.grad`` is ``None`` are skipped.
        set_to_none: if True, drop each gradient (``param.grad = None``)
            instead of zero-filling it in place. Defaults to False, which
            keeps the original in-place zeroing behavior.
    """
    for param in model_params:
        grad = param.grad
        if grad is None:
            continue
        if set_to_none:
            param.grad = None
            continue
        # Taken from
        # https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.zero_grad
        # Only a gradient that has a grad_fn is attached to an autograd graph
        # and needs detaching. In-place detach_() raises on view tensors
        # (e.g. DDP bucket-view gradients), so a plain gradient just has
        # requires_grad switched off instead.
        if grad.grad_fn is not None:
            grad.detach_()
        else:
            grad.requires_grad_(False)
        grad.zero_()
| 139 | |
| 140 | |
| 141 | def param_grad_or_zeros(param): |