(self, got, expected, atol, rtol)
| 32 | return [p.grad for p in get_parameters(model)] |
| 33 | |
def _assert_close(self, got, expected, atol, rtol):
    """Assert that two gradient sequences match pairwise within tolerance.

    Args:
        got: gradients produced by the code path under test; entries may be
            ``None`` (e.g. a parameter that received no gradient).
        expected: reference gradients, in the same order as ``got``.
        atol: absolute tolerance forwarded to ``allclose``.
        rtol: relative tolerance forwarded to ``allclose``.

    Raises:
        AssertionError: if the sequences differ in length, exactly one side of
            a pair is ``None``, or any pair is not close.
    """
    got, expected = list(got), list(expected)
    # zip() silently truncates, so a dropped/extra gradient would otherwise pass.
    self.assertEqual(len(got), len(expected), "gradient count mismatch")
    for i, (g, e) in enumerate(zip(got, expected)):
        if g is None or e is None:
            # Both paths producing no gradient is agreement; only one is a mismatch.
            if g is not None or e is not None:
                self.fail(f"grad mismatch at index {i}: one side is None")
            continue
        if g.allclose(e, atol=atol, rtol=rtol).item():
            continue
        # Compute the diagnostic only on failure; it costs an extra evaluation.
        max_diff = (g - e).abs().max().item()
        self.fail(f"grad mismatch at index {i} (max abs diff {max_diff})")
| 37 | |
def _assert_match(self, model, xs, atol, rtol):
    """Compare the apply-grad path's gradients with the reference path's.

    Runs ``self._run_with_apply_grad`` first and then ``self._run_reference``
    on the same ``model``/``xs``. The two gradient lists are then checked with
    ``self._assert_close`` using the given tolerances.
    """
    candidate = self._run_with_apply_grad(model, xs)
    reference = self._run_reference(model, xs)
    self._assert_close(candidate, reference, atol, rtol)
no test coverage detected