Set the `requires_grad` flag on all parameters in a model.
`requires_grad(model, flag=True)`
| 200 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) |
| 201 | |
def requires_grad(model, flag: bool = True) -> None:
    """
    Set the ``requires_grad`` attribute on every parameter of ``model``.

    Args:
        model: Object whose ``parameters()`` method yields the parameters to
            update. Presumably a ``torch.nn.Module``; confirm at call sites.
        flag: Value assigned to each parameter's ``requires_grad``. Defaults
            to ``True``. With PyTorch parameters, ``False`` freezes them
            (no gradients are tracked) and ``True`` unfreezes them.

    Returns:
        None. The parameters are modified in place.
    """
    for param in model.parameters():
        param.requires_grad = flag
| 208 | |
| 209 | def cleanup(): |
| 210 | """ |