(self, grad_scale=1.0)
| 215 | return True |
| 216 | |
| 217 | def _compute_norms(self, grad_scale=1.0): |
| 218 | grad_norm = 0.0 |
| 219 | param_norm = 0.0 |
| 220 | for p in self.master_params: |
| 221 | with th.no_grad(): |
| 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 |
| 223 | if p.grad is not None: |
| 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 |
| 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) |
| 226 | |
| 227 | def master_params_to_state_dict(self, master_params): |
| 228 | return master_params_to_state_dict( |
no outgoing calls
no test coverage detected