MCPcopy
hub / github.com/dmlc/dgl / _test_load_graph_data_from_csv

Function _test_load_graph_data_from_csv

tests/python/common/data/test_data.py:1097–1157  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1095
1096
1097def _test_load_graph_data_from_csv():
1098 from dgl.data.csv_dataset_base import (
1099 DefaultDataParser,
1100 GraphData,
1101 MetaGraph,
1102 )
1103
1104 with tempfile.TemporaryDirectory() as test_dir:
1105 num_graphs = 100
1106 # minimum
1107 df = pd.DataFrame({"graph_id": np.arange(num_graphs)})
1108 csv_path = os.path.join(test_dir, "graph.csv")
1109 df.to_csv(csv_path, index=False)
1110 meta_graph = MetaGraph(file_name=csv_path)
1111 graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
1112 assert np.array_equal(df["graph_id"], graph_data.graph_id)
1113 assert len(graph_data.data) == 0
1114
1115 # common case
1116 df = pd.DataFrame(
1117 {
1118 "graph_id": np.arange(num_graphs),
1119 "label": np.random.randint(3, size=num_graphs),
1120 }
1121 )
1122 csv_path = os.path.join(test_dir, "graph.csv")
1123 df.to_csv(csv_path, index=False)
1124 meta_graph = MetaGraph(file_name=csv_path)
1125 graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
1126 assert np.array_equal(df["graph_id"], graph_data.graph_id)
1127 assert len(graph_data.data) == 1
1128 assert np.array_equal(df["label"], graph_data.data["label"])
1129
1130 # add more fields into graph.csv
1131 df = pd.DataFrame(
1132 {
1133 "graph_id": np.arange(num_graphs),
1134 "feat": np.random.randint(3, size=num_graphs),
1135 "label": np.random.randint(3, size=num_graphs),
1136 }
1137 )
1138 csv_path = os.path.join(test_dir, "graph.csv")
1139 df.to_csv(csv_path, index=False)
1140 meta_graph = MetaGraph(file_name=csv_path)
1141 graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
1142 assert np.array_equal(df["graph_id"], graph_data.graph_id)
1143 assert len(graph_data.data) == 2
1144 assert np.array_equal(df["feat"], graph_data.data["feat"])
1145 assert np.array_equal(df["label"], graph_data.data["label"])
1146
1147 # required header is missing
1148 df = pd.DataFrame({"label": np.random.randint(3, size=num_graphs)})
1149 csv_path = os.path.join(test_dir, "graph.csv")
1150 df.to_csv(csv_path, index=False)
1151 meta_graph = MetaGraph(file_name=csv_path)
1152 expect_except = False
1153 try:
1154 GraphData.load_from_csv(meta_graph, DefaultDataParser())

Callers 1

test_csvdataset (Function) · 0.85

Calls 4

MetaGraph (Class) · 0.90
DefaultDataParser (Class) · 0.90
join (Method) · 0.45
load_from_csv (Method) · 0.45

Tested by

no test coverage detected