| 109 | @pytest.mark.parametrize("save_fmt", ["numpy", "torch"]) |
| 110 | @pytest.mark.parametrize("is_feature", [True, False]) |
def test_copy_or_convert_data(data_fmt, save_fmt, is_feature):
    """Check ``internal.copy_or_convert_data`` for each input/output format pair.

    A 1-D ``np.arange(10)`` array is written to a temporary file in
    ``data_fmt`` (``.npy`` via ``np.save`` for "numpy", ``.pt`` via
    ``torch.save`` otherwise) and passed to ``copy_or_convert_data`` with the
    requested ``save_fmt``.

    * ``save_fmt == "torch"``: the call is expected to raise ``AssertionError``.
    * otherwise the ``.npy`` output is loaded and must equal the input values,
      reshaped to a column vector ``(-1, 1)`` when ``is_feature`` is True and
      left in its original shape when it is False.
    """
    with tempfile.TemporaryDirectory() as test_dir:
        data = np.arange(10)
        tensor_data = torch.from_numpy(data)
        in_type_name = "npy" if data_fmt == "numpy" else "pt"
        input_path = os.path.join(test_dir, f"data.{in_type_name}")
        out_type_name = "npy" if save_fmt == "numpy" else "pt"
        output_path = os.path.join(test_dir, f"out_data.{out_type_name}")
        if data_fmt == "numpy":
            np.save(input_path, data)
        else:
            torch.save(tensor_data, input_path)

        if save_fmt == "torch":
            # Writing the output in torch format is expected to be rejected.
            with pytest.raises(AssertionError):
                internal.copy_or_convert_data(
                    input_path,
                    output_path,
                    data_fmt,
                    save_fmt,
                    is_feature=is_feature,
                )
            return

        internal.copy_or_convert_data(
            input_path,
            output_path,
            data_fmt,
            save_fmt,
            is_feature=is_feature,
        )
        # Feature data is expected back as a column vector.
        expected = data.reshape(-1, 1) if is_feature else data
        out_data = np.load(output_path)
        # assert_array_equal also fails on a shape mismatch, which a plain
        # ``(a == b).all()`` can hide through broadcasting.
        np.testing.assert_array_equal(out_data, expected)
| 150 | |
| 151 | |
| 152 | @pytest.mark.parametrize("edge_fmt", ["csv", "numpy"]) |