MCPcopy
hub / github.com/dmlc/dgl / test_save_data

Function test_save_data

tests/python/pytorch/graphbolt/internal/test_utils.py:64–89  ·  view source on GitHub ↗
(data_fmt, save_fmt, contiguous)

Source from the content-addressed store, hash-verified

62 ],
63)
def test_save_data(data_fmt, save_fmt, contiguous):
    """Round-trip a small array through ``internal.save_data`` and reload it.

    The same 2x3 integer array is passed to ``internal.save_data`` either as
    a ``torch.Tensor`` (``data_fmt == "torch"``) or as a ``numpy.ndarray``
    (``data_fmt == "numpy"``), written with ``save_fmt``, then read back with
    ``torch.load`` or ``np.load`` respectively. The reloaded data must be
    contiguous (C-contiguous for numpy) and equal to the input, including
    when ``contiguous`` is False and the input is Fortran-ordered.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        source = np.array([[1, 2, 4], [2, 5, 3]])
        if not contiguous:
            # Column-major layout yields a non-C-contiguous array/tensor.
            source = np.asfortranarray(source)
        source_tensor = torch.from_numpy(source)
        suffix = "npy" if save_fmt != "torch" else "pt"
        target = os.path.join(tmp_dir, f"save_data.{suffix}")

        # Step 1: save the data in the requested input representation.
        payload = {"torch": source_tensor, "numpy": source}.get(data_fmt)
        if payload is not None:
            internal.save_data(payload, target, save_fmt)

        # Step 2: load it back and compare against the original tensor.
        restored = None
        if save_fmt == "torch":
            restored = torch.load(target)
            assert restored.is_contiguous()
            assert torch.equal(source_tensor, restored)
        elif save_fmt == "numpy":
            restored = np.load(target)
            # The reloaded array must be C-contiguous even when the source
            # array was Fortran-ordered.
            assert restored.flags["C_CONTIGUOUS"]
            assert np.array_equal(source_tensor.numpy(), restored)

        # Release references before the temporary directory is cleaned up
        # (presumably so nothing keeps the saved file alive; confirm).
        source = source_tensor = payload = restored = None
90
91
92@pytest.mark.parametrize("fmt", ["torch", "numpy"])

Callers

nothing calls this directly

Calls 2

join · Method · 0.45
load · Method · 0.45

Tested by

no test coverage detected