Test that a Python value matches the recorded contents of a file based on a "check" name. The value must be picklable with `torch.save`. This file is placed in the 'expect' directory in the same directory as the test script. You can automatically update the recorded test output using an EXPECTTEST_ACCEPT=1 env variable.
(output, name, prec=None, atol=None, rtol=None)
| 132 | |
| 133 | |
def _assert_expected(output, name, prec=None, atol=None, rtol=None):
    """Test that a Python value matches the recorded contents of a file
    based on a "check" name.

    The value must be picklable with `torch.save`. The file is placed in
    the 'expect' directory in the same directory as the test script (its
    path is resolved by ``_get_expected_file``). You can automatically
    update the recorded test output using an EXPECTTEST_ACCEPT=1 env
    variable, which drives the module-level ``ACCEPT`` flag.

    Args:
        output: The value to record (accept mode) or to compare against the
            recorded value. Must be serializable with ``torch.save`` and
            loadable back with ``torch.load(..., weights_only=True)``.
        name: The "check" name used to locate the expected-output file.
        prec: Legacy tolerance, used for ``rtol`` and/or ``atol`` whenever
            that tolerance is not given explicitly.
        atol: Absolute tolerance forwarded to ``torch.testing.assert_close``.
        rtol: Relative tolerance forwarded to ``torch.testing.assert_close``.
            Dtype and device differences are not checked.

    Raises:
        RuntimeError: In accept mode, if the saved file is larger than
            50 KB (50,000 bytes). The oversized file is left on disk.
        AssertionError: In compare mode, if ``output`` does not match the
            recorded value (raised by ``torch.testing.assert_close``).
    """
    expected_file = _get_expected_file(name)

    if ACCEPT:
        # A plain string (not a `{...}` set display), so the messages below
        # show the bare file name rather than "{'file.pkl'}".
        filename = os.path.basename(expected_file)
        print(f"Accepting updated output for {filename}:\n\n{output}")
        torch.save(output, expected_file)
        MAX_PICKLE_SIZE = 50 * 1000  # 50 KB, in bytes
        binary_size = os.path.getsize(expected_file)  # in bytes
        if binary_size > MAX_PICKLE_SIZE:
            # Report both sizes in the same unit; getsize() returns bytes.
            raise RuntimeError(
                f"The output for {filename} is larger than {MAX_PICKLE_SIZE // 1000}kb - "
                f"got {binary_size / 1000:.1f}kb"
            )
    else:
        expected = torch.load(expected_file, weights_only=True)
        # `prec` is kept for legacy callers. Compare against None rather than
        # relying on truthiness, so an explicit tolerance of 0 is honoured
        # instead of silently falling back to `prec` (or assert_close defaults).
        rtol = prec if rtol is None else rtol
        atol = prec if atol is None else atol
        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
| 157 | |
| 158 | |
| 159 | def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None): |
no test coverage detected
searching dependent graphs…