| 216 | |
| 217 | |
def print_tensor_test(
    tensor,
    limit_to_slices=None,
    max_torch_print=None,
    filename="test_corrections.txt",
    expected_tensor_name="expected_slice",
):
    """Append ``tensor``'s values to ``filename`` as an ``np.array`` assignment labelled with the current test.

    Must be called while pytest is running a test: the test's node id is read from the
    ``PYTEST_CURRENT_TEST`` environment variable. Each call appends one line of the form::

        <test_file>::<test_class>::<test_fn>::<expected_tensor_name> = np.array([...])

    ``<test_class>`` is the empty string for tests defined at module level.

    Args:
        tensor: A ``torch.Tensor``; anything else is converted with ``torch.from_numpy``.
        limit_to_slices: If truthy, only ``tensor[0, -3:, -3:, -1]`` is written (the tensor
            must then have at least four dimensions).
        max_torch_print: If truthy, torch's print threshold is raised to 10_000 elements while
            the tensor is formatted, so tensors up to that size are written in full instead of
            being summarized with ``...``. The previous threshold is restored afterwards.
        filename: File the line is appended to (created if missing).
        expected_tensor_name: Variable name used on the left-hand side of the written assignment.

    Raises:
        RuntimeError: If ``PYTEST_CURRENT_TEST`` is unset or does not look like a pytest node id.
    """
    test_name = os.environ.get("PYTEST_CURRENT_TEST")
    if not test_name or "::" not in test_name:
        raise RuntimeError(
            "print_tensor_test must be called from inside a running pytest test "
            f"(PYTEST_CURRENT_TEST={test_name!r})"
        )
    # Node ids look like "path/test_file.py::TestClass::test_fn[params] (call)". Take the file
    # up to the first "::" and the function after the last one, so module-level tests (no class)
    # and nested classes work instead of failing a fixed three-way unpack.
    test_file, _, remainder = test_name.partition("::")
    test_class, _, test_fn = remainder.rpartition("::")
    # Drop the trailing phase marker such as " (call)".
    test_fn = test_fn.split()[0]

    if not torch.is_tensor(tensor):
        tensor = torch.from_numpy(tensor)
    if limit_to_slices:
        tensor = tensor[0, -3:, -3:, -1]
    flat = tensor.detach().cpu().flatten().to(torch.float32)

    if max_torch_print:
        # torch has no public getter for print options; read the current threshold directly
        # so it can be restored and the change does not leak into later tensor prints.
        previous_threshold = torch._tensor_str.PRINT_OPTS.threshold
        torch.set_printoptions(threshold=10_000)
        try:
            tensor_str = str(flat)
        finally:
            torch.set_printoptions(threshold=previous_threshold)
    else:
        tensor_str = str(flat)

    # str() yields "tensor([...])", possibly spread over several lines; the output must be one line.
    # format is usually:
    # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
    output_str = tensor_str.replace("\n", "").replace("tensor", f"{expected_tensor_name} = np.array")
    with open(filename, "a", encoding="utf-8") as f:
        print("::".join([test_file, test_class, test_fn, output_str]), file=f)
| 242 | |
| 243 | |
| 244 | def get_tests_dir(append_path=None): |