MCPcopy
hub / github.com/dmlc/dgl / test_copy_or_convert_data

Function test_copy_or_convert_data

tests/python/pytorch/graphbolt/internal/test_utils.py:111–149  ·  view source on GitHub ↗
(data_fmt, save_fmt, is_feature)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("save_fmt", ["numpy", "torch"])
@pytest.mark.parametrize("is_feature", [True, False])
def test_copy_or_convert_data(data_fmt, save_fmt, is_feature):
    """Check ``internal.copy_or_convert_data`` across input/output formats.

    A 1-D ``np.arange(10)`` array is written to a temporary file in the
    ``data_fmt`` format and passed to ``internal.copy_or_convert_data``.

    * ``save_fmt == "torch"``: the call is expected to raise
      ``AssertionError``.
    * ``save_fmt == "numpy"``: the produced ``.npy`` file is loaded back and
      must equal the input exactly (values AND shape), reshaped to
      ``(-1, 1)`` when ``is_feature`` is true.

    Args:
        data_fmt: Input format. ``"numpy"`` writes the input with
            ``np.save``; any other value writes it with ``torch.save``.
            NOTE(review): its ``pytest.mark.parametrize`` decorator is not
            visible in this excerpt (presumably ``["numpy", "torch"]``) —
            confirm it is present in the real file.
        save_fmt: Requested output format, ``"numpy"`` or ``"torch"``.
        is_feature: Forwarded to ``copy_or_convert_data``; when true the
            expected output is a column vector of shape ``(N, 1)``.
    """
    with tempfile.TemporaryDirectory() as test_dir:
        data = np.arange(10)
        tensor_data = torch.from_numpy(data)
        in_type_name = "npy" if data_fmt == "numpy" else "pt"
        input_path = os.path.join(test_dir, f"data.{in_type_name}")
        out_type_name = "npy" if save_fmt == "numpy" else "pt"
        output_path = os.path.join(test_dir, f"out_data.{out_type_name}")
        if data_fmt == "numpy":
            np.save(input_path, data)
        else:
            torch.save(tensor_data, input_path)

        if save_fmt == "torch":
            # This test expects torch output to be rejected.
            with pytest.raises(AssertionError):
                internal.copy_or_convert_data(
                    input_path,
                    output_path,
                    data_fmt,
                    save_fmt,
                    is_feature=is_feature,
                )
        else:
            internal.copy_or_convert_data(
                input_path,
                output_path,
                data_fmt,
                save_fmt,
                is_feature=is_feature,
            )
            expected = data.reshape(-1, 1) if is_feature else data
            out_data = np.load(output_path)
            # Unlike `(a == b).all()`, assert_array_equal refuses to
            # broadcast mismatched shapes, so a missing (-1, 1) reshape for
            # features is reported as a shape mismatch.
            np.testing.assert_array_equal(out_data, expected)
            out_data = None

        # Release array references before the temporary directory is
        # removed. NOTE(review): presumably guards against lingering file
        # handles blocking cleanup on some platforms — confirm.
        data = None
        tensor_data = None
150
151
152@pytest.mark.parametrize("edge_fmt", ["csv", "numpy"])

Callers

nothing calls this directly

Calls 3

join · Method · 0.45
save · Method · 0.45
load · Method · 0.45

Tested by

no test coverage detected