()
| 1095 | |
| 1096 | |
| 1097 | def _test_load_graph_data_from_csv(): |
| 1098 | from dgl.data.csv_dataset_base import ( |
| 1099 | DefaultDataParser, |
| 1100 | GraphData, |
| 1101 | MetaGraph, |
| 1102 | ) |
| 1103 | |
| 1104 | with tempfile.TemporaryDirectory() as test_dir: |
| 1105 | num_graphs = 100 |
| 1106 | # minimum |
| 1107 | df = pd.DataFrame({"graph_id": np.arange(num_graphs)}) |
| 1108 | csv_path = os.path.join(test_dir, "graph.csv") |
| 1109 | df.to_csv(csv_path, index=False) |
| 1110 | meta_graph = MetaGraph(file_name=csv_path) |
| 1111 | graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser()) |
| 1112 | assert np.array_equal(df["graph_id"], graph_data.graph_id) |
| 1113 | assert len(graph_data.data) == 0 |
| 1114 | |
| 1115 | # common case |
| 1116 | df = pd.DataFrame( |
| 1117 | { |
| 1118 | "graph_id": np.arange(num_graphs), |
| 1119 | "label": np.random.randint(3, size=num_graphs), |
| 1120 | } |
| 1121 | ) |
| 1122 | csv_path = os.path.join(test_dir, "graph.csv") |
| 1123 | df.to_csv(csv_path, index=False) |
| 1124 | meta_graph = MetaGraph(file_name=csv_path) |
| 1125 | graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser()) |
| 1126 | assert np.array_equal(df["graph_id"], graph_data.graph_id) |
| 1127 | assert len(graph_data.data) == 1 |
| 1128 | assert np.array_equal(df["label"], graph_data.data["label"]) |
| 1129 | |
| 1130 | # add more fields into graph.csv |
| 1131 | df = pd.DataFrame( |
| 1132 | { |
| 1133 | "graph_id": np.arange(num_graphs), |
| 1134 | "feat": np.random.randint(3, size=num_graphs), |
| 1135 | "label": np.random.randint(3, size=num_graphs), |
| 1136 | } |
| 1137 | ) |
| 1138 | csv_path = os.path.join(test_dir, "graph.csv") |
| 1139 | df.to_csv(csv_path, index=False) |
| 1140 | meta_graph = MetaGraph(file_name=csv_path) |
| 1141 | graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser()) |
| 1142 | assert np.array_equal(df["graph_id"], graph_data.graph_id) |
| 1143 | assert len(graph_data.data) == 2 |
| 1144 | assert np.array_equal(df["feat"], graph_data.data["feat"]) |
| 1145 | assert np.array_equal(df["label"], graph_data.data["label"]) |
| 1146 | |
| 1147 | # required header is missing |
| 1148 | df = pd.DataFrame({"label": np.random.randint(3, size=num_graphs)}) |
| 1149 | csv_path = os.path.join(test_dir, "graph.csv") |
| 1150 | df.to_csv(csv_path, index=False) |
| 1151 | meta_graph = MetaGraph(file_name=csv_path) |
| 1152 | expect_except = False |
| 1153 | try: |
| 1154 | GraphData.load_from_csv(meta_graph, DefaultDataParser()) |
no test coverage detected