| 62 | ], |
| 63 | ) |
def test_save_data(data_fmt, save_fmt, contiguous):
    """Round-trip ``internal.save_data`` through a file in a temp directory.

    A small integer array is passed to ``internal.save_data`` either as a
    numpy array or as a torch tensor (per ``data_fmt``) and written in the
    ``save_fmt`` format. It is then loaded back with the matching library
    loader. The loaded data must equal the input and be C-contiguous, even
    when the input was not.

    Args:
        data_fmt: ``"torch"`` or ``"numpy"``; the type of object handed to
            ``internal.save_data``.
        save_fmt: ``"torch"`` or ``"numpy"``; the on-disk format requested
            (and the loader used to read it back).
        contiguous: when falsy, the input is converted to Fortran order so
            the saver is exercised with non-C-contiguous data.

    Raises:
        ValueError: if ``data_fmt`` or ``save_fmt`` is not a supported value,
            so a bad parametrization fails instead of passing vacuously.
    """
    with tempfile.TemporaryDirectory() as test_dir:
        data = np.array([[1, 2, 4], [2, 5, 3]])
        if not contiguous:
            data = np.asfortranarray(data)
        # torch.from_numpy shares memory and strides with ``data``.
        tensor_data = torch.from_numpy(data)
        # Guard the precondition. Without it the non-contiguous case could
        # silently degenerate into a duplicate of the contiguous one.
        assert tensor_data.is_contiguous() == bool(contiguous)

        type_name = "pt" if save_fmt == "torch" else "npy"
        save_file_name = os.path.join(test_dir, f"save_data.{type_name}")

        # Step1. Save the data.
        if data_fmt == "torch":
            internal.save_data(tensor_data, save_file_name, save_fmt)
        elif data_fmt == "numpy":
            internal.save_data(data, save_file_name, save_fmt)
        else:
            # Fail loudly rather than skipping the save and passing anyway.
            raise ValueError(f"Unsupported data_fmt: {data_fmt!r}")

        # Step2. Load the data.
        if save_fmt == "torch":
            loaded_data = torch.load(save_file_name)
            assert loaded_data.is_contiguous()
            assert torch.equal(tensor_data, loaded_data)
        elif save_fmt == "numpy":
            loaded_data = np.load(save_file_name)
            # Checks if the loaded data is C-contiguous.
            assert loaded_data.flags["C_CONTIGUOUS"]
            assert np.array_equal(tensor_data.numpy(), loaded_data)
        else:
            # Without this branch nothing would be asserted for an unknown
            # format and the test would pass vacuously.
            raise ValueError(f"Unsupported save_fmt: {save_fmt!r}")

        # Drop references to the arrays before the temporary directory is
        # removed. Presumably this keeps nothing pinned to the saved file
        # during cleanup (e.g. on Windows); confirm before removing.
        data = tensor_data = loaded_data = None
| 90 | |
| 91 | |
| 92 | @pytest.mark.parametrize("fmt", ["torch", "numpy"]) |