| 166 | |
| 167 | |
def print_tensor_test(
    tensor,
    limit_to_slices=None,
    max_torch_print=None,
    filename="test_corrections.txt",
    expected_tensor_name="expected_slice",
):
    """Append the current pytest test's tensor values to *filename* as a paste-ready line.

    Each call appends one line of the form::

        <test_file>::<test_class>::<test_fn>::<expected_tensor_name> = np.array([...])

    The test identifiers are parsed from the ``PYTEST_CURRENT_TEST`` environment
    variable. Module-level tests (no class) get an empty ``<test_class>`` field,
    so every line still has four ``::``-separated fields.

    Args:
        tensor: A ``torch.Tensor``, or anything ``torch.as_tensor`` accepts
            (e.g. a NumPy array, which is converted without copying).
        limit_to_slices: If truthy, only ``tensor[0, -3:, -3:, -1]`` is recorded.
            This indexing requires an input with at least 4 dimensions.
        max_torch_print: If truthy, raises torch's print threshold to 10_000
            elements so larger tensors are printed without "..." summarization.
            NOTE: ``torch.set_printoptions`` is process-wide and is not restored.
        filename: Path of the file the line is appended to.
        expected_tensor_name: Variable name on the left-hand side of the
            emitted assignment.

    Raises:
        RuntimeError: If ``PYTEST_CURRENT_TEST`` is not set, i.e. the function
            is not being called from inside a running pytest test.
    """
    if max_torch_print:
        torch.set_printoptions(threshold=10_000)

    test_name = os.environ.get("PYTEST_CURRENT_TEST")
    if test_name is None:
        raise RuntimeError(
            "print_tensor_test must be called from within a pytest test "
            "(PYTEST_CURRENT_TEST is not set)."
        )

    if not torch.is_tensor(tensor):
        tensor = torch.as_tensor(tensor)
    if limit_to_slices:
        tensor = tensor[0, -3:, -3:, -1]

    # Flatten to 1-D float32 so the printed form is a single flat list; newlines
    # that torch inserts for long tensors are removed to keep one line per test.
    tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
    # format is usually:
    # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
    output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")

    # Node ids look like "path/test_file.py::TestClass::test_fn[param] (call)".
    # Module-level tests have no class segment and nested classes have several,
    # so don't assume exactly three "::"-separated parts.
    test_file, *class_parts, test_fn = test_name.split("::")
    test_class = "::".join(class_parts)
    # Drop the " (setup|call|teardown)" phase suffix pytest appends.
    test_fn = test_fn.split()[0]

    with open(filename, "a", encoding="utf-8") as f:
        print("::".join([test_file, test_class, test_fn, output_str]), file=f)
| 192 | |
| 193 | |
| 194 | def get_tests_dir(append_path=None): |