()
| 1006 | |
| 1007 | |
| 1008 | def _test_load_edge_data_from_csv(): |
| 1009 | from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge |
| 1010 | |
| 1011 | with tempfile.TemporaryDirectory() as test_dir: |
| 1012 | num_nodes = 100 |
| 1013 | num_edges = 1000 |
| 1014 | # minimum |
| 1015 | df = pd.DataFrame( |
| 1016 | { |
| 1017 | "src_id": np.random.randint(num_nodes, size=num_edges), |
| 1018 | "dst_id": np.random.randint(num_nodes, size=num_edges), |
| 1019 | } |
| 1020 | ) |
| 1021 | csv_path = os.path.join(test_dir, "edges.csv") |
| 1022 | df.to_csv(csv_path, index=False) |
| 1023 | meta_edge = MetaEdge(file_name=csv_path) |
| 1024 | edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser()) |
| 1025 | assert np.array_equal(df["src_id"], edge_data.src) |
| 1026 | assert np.array_equal(df["dst_id"], edge_data.dst) |
| 1027 | assert len(edge_data.data) == 0 |
| 1028 | |
| 1029 | # common case |
| 1030 | df = pd.DataFrame( |
| 1031 | { |
| 1032 | "src_id": np.random.randint(num_nodes, size=num_edges), |
| 1033 | "dst_id": np.random.randint(num_nodes, size=num_edges), |
| 1034 | "label": np.random.randint(3, size=num_edges), |
| 1035 | } |
| 1036 | ) |
| 1037 | csv_path = os.path.join(test_dir, "edges.csv") |
| 1038 | df.to_csv(csv_path, index=False) |
| 1039 | meta_edge = MetaEdge(file_name=csv_path) |
| 1040 | edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser()) |
| 1041 | assert np.array_equal(df["src_id"], edge_data.src) |
| 1042 | assert np.array_equal(df["dst_id"], edge_data.dst) |
| 1043 | assert len(edge_data.data) == 1 |
| 1044 | assert np.array_equal(df["label"], edge_data.data["label"]) |
| 1045 | assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id) |
| 1046 | assert edge_data.type == ("_V", "_E", "_V") |
| 1047 | |
| 1048 | # add more fields into edges.csv |
| 1049 | df = pd.DataFrame( |
| 1050 | { |
| 1051 | "src_id": np.random.randint(num_nodes, size=num_edges), |
| 1052 | "dst_id": np.random.randint(num_nodes, size=num_edges), |
| 1053 | "graph_id": np.arange(num_edges), |
| 1054 | "feat": np.random.randint(3, size=num_edges), |
| 1055 | "label": np.random.randint(3, size=num_edges), |
| 1056 | } |
| 1057 | ) |
| 1058 | csv_path = os.path.join(test_dir, "edges.csv") |
| 1059 | df.to_csv(csv_path, index=False) |
| 1060 | meta_edge = MetaEdge(file_name=csv_path) |
| 1061 | edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser()) |
| 1062 | assert np.array_equal(df["src_id"], edge_data.src) |
| 1063 | assert np.array_equal(df["dst_id"], edge_data.dst) |
| 1064 | assert len(edge_data.data) == 2 |
| 1065 | assert np.array_equal(df["feat"], edge_data.data["feat"]) |
no test coverage detected