r"""Return the gradient norm of model. Args: model (PyTorch module): Your network.
(model)
| 199 | |
| 200 | |
def gradient_norm(model):
    r"""Compute the global L2 norm of the gradients held by ``model``.

    Each parameter that currently has a gradient contributes the square of
    its gradient's 2-norm; the result is the square root of that sum.
    Parameters whose ``grad`` is ``None`` are skipped.

    Args:
        model (PyTorch module): Your network.

    Returns:
        The combined gradient norm as a Python number (``0.0`` when no
        parameter has a gradient).

    """
    # Squared 2-norm of every available gradient, pulled out as Python floats.
    squared_norms = (
        weight.grad.norm(2).item() ** 2
        for weight in model.parameters()
        if weight.grad is not None
    )
    return sum(squared_norms) ** (1. / 2)
| 214 | |
| 215 | |
| 216 | def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'): |