MCPcopy
hub / github.com/dmlc/dgl / _test_load_edge_data_from_csv

Function _test_load_edge_data_from_csv

tests/python/common/data/test_data.py:1008–1094  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1006
1007
1008def _test_load_edge_data_from_csv():
1009 from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge
1010
1011 with tempfile.TemporaryDirectory() as test_dir:
1012 num_nodes = 100
1013 num_edges = 1000
1014 # minimum
1015 df = pd.DataFrame(
1016 {
1017 "src_id": np.random.randint(num_nodes, size=num_edges),
1018 "dst_id": np.random.randint(num_nodes, size=num_edges),
1019 }
1020 )
1021 csv_path = os.path.join(test_dir, "edges.csv")
1022 df.to_csv(csv_path, index=False)
1023 meta_edge = MetaEdge(file_name=csv_path)
1024 edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
1025 assert np.array_equal(df["src_id"], edge_data.src)
1026 assert np.array_equal(df["dst_id"], edge_data.dst)
1027 assert len(edge_data.data) == 0
1028
1029 # common case
1030 df = pd.DataFrame(
1031 {
1032 "src_id": np.random.randint(num_nodes, size=num_edges),
1033 "dst_id": np.random.randint(num_nodes, size=num_edges),
1034 "label": np.random.randint(3, size=num_edges),
1035 }
1036 )
1037 csv_path = os.path.join(test_dir, "edges.csv")
1038 df.to_csv(csv_path, index=False)
1039 meta_edge = MetaEdge(file_name=csv_path)
1040 edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
1041 assert np.array_equal(df["src_id"], edge_data.src)
1042 assert np.array_equal(df["dst_id"], edge_data.dst)
1043 assert len(edge_data.data) == 1
1044 assert np.array_equal(df["label"], edge_data.data["label"])
1045 assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id)
1046 assert edge_data.type == ("_V", "_E", "_V")
1047
1048 # add more fields into edges.csv
1049 df = pd.DataFrame(
1050 {
1051 "src_id": np.random.randint(num_nodes, size=num_edges),
1052 "dst_id": np.random.randint(num_nodes, size=num_edges),
1053 "graph_id": np.arange(num_edges),
1054 "feat": np.random.randint(3, size=num_edges),
1055 "label": np.random.randint(3, size=num_edges),
1056 }
1057 )
1058 csv_path = os.path.join(test_dir, "edges.csv")
1059 df.to_csv(csv_path, index=False)
1060 meta_edge = MetaEdge(file_name=csv_path)
1061 edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
1062 assert np.array_equal(df["src_id"], edge_data.src)
1063 assert np.array_equal(df["dst_id"], edge_data.dst)
1064 assert len(edge_data.data) == 2
1065 assert np.array_equal(df["feat"], edge_data.data["feat"])

Callers 1

test_csvdataset · Function · 0.85

Calls 4

MetaEdge · Class · 0.90
DefaultDataParser · Class · 0.90
join · Method · 0.45
load_from_csv · Method · 0.45

Tested by

no test coverage detected