()
| 944 | |
| 945 | |
| 946 | def _test_load_node_data_from_csv(): |
| 947 | from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData |
| 948 | |
| 949 | with tempfile.TemporaryDirectory() as test_dir: |
| 950 | num_nodes = 100 |
| 951 | # minimum |
| 952 | df = pd.DataFrame({"node_id": np.arange(num_nodes)}) |
| 953 | csv_path = os.path.join(test_dir, "nodes.csv") |
| 954 | df.to_csv(csv_path, index=False) |
| 955 | meta_node = MetaNode(file_name=csv_path) |
| 956 | node_data = NodeData.load_from_csv(meta_node, DefaultDataParser()) |
| 957 | assert np.array_equal(df["node_id"], node_data.id) |
| 958 | assert len(node_data.data) == 0 |
| 959 | |
| 960 | # common case |
| 961 | df = pd.DataFrame( |
| 962 | { |
| 963 | "node_id": np.arange(num_nodes), |
| 964 | "label": np.random.randint(3, size=num_nodes), |
| 965 | } |
| 966 | ) |
| 967 | csv_path = os.path.join(test_dir, "nodes.csv") |
| 968 | df.to_csv(csv_path, index=False) |
| 969 | meta_node = MetaNode(file_name=csv_path) |
| 970 | node_data = NodeData.load_from_csv(meta_node, DefaultDataParser()) |
| 971 | assert np.array_equal(df["node_id"], node_data.id) |
| 972 | assert len(node_data.data) == 1 |
| 973 | assert np.array_equal(df["label"], node_data.data["label"]) |
| 974 | assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id) |
| 975 | assert node_data.type == "_V" |
| 976 | |
| 977 | # add more fields into nodes.csv |
| 978 | df = pd.DataFrame( |
| 979 | { |
| 980 | "node_id": np.arange(num_nodes), |
| 981 | "label": np.random.randint(3, size=num_nodes), |
| 982 | "graph_id": np.full(num_nodes, 1), |
| 983 | } |
| 984 | ) |
| 985 | csv_path = os.path.join(test_dir, "nodes.csv") |
| 986 | df.to_csv(csv_path, index=False) |
| 987 | meta_node = MetaNode(file_name=csv_path) |
| 988 | node_data = NodeData.load_from_csv(meta_node, DefaultDataParser()) |
| 989 | assert np.array_equal(df["node_id"], node_data.id) |
| 990 | assert len(node_data.data) == 1 |
| 991 | assert np.array_equal(df["label"], node_data.data["label"]) |
| 992 | assert np.array_equal(df["graph_id"], node_data.graph_id) |
| 993 | assert node_data.type == "_V" |
| 994 | |
| 995 | # required header is missing |
| 996 | df = pd.DataFrame({"label": np.random.randint(3, size=num_nodes)}) |
| 997 | csv_path = os.path.join(test_dir, "nodes.csv") |
| 998 | df.to_csv(csv_path, index=False) |
| 999 | meta_node = MetaNode(file_name=csv_path) |
| 1000 | expect_except = False |
| 1001 | try: |
| 1002 | NodeData.load_from_csv(meta_node, DefaultDataParser()) |
| 1003 | except: |
no test coverage detected