MCPcopy
hub / github.com/dmlc/dgl / _test_load_node_data_from_csv

Function _test_load_node_data_from_csv

tests/python/common/data/test_data.py:946–1005  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

944
945
946def _test_load_node_data_from_csv():
947 from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData
948
949 with tempfile.TemporaryDirectory() as test_dir:
950 num_nodes = 100
951 # minimum
952 df = pd.DataFrame({"node_id": np.arange(num_nodes)})
953 csv_path = os.path.join(test_dir, "nodes.csv")
954 df.to_csv(csv_path, index=False)
955 meta_node = MetaNode(file_name=csv_path)
956 node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
957 assert np.array_equal(df["node_id"], node_data.id)
958 assert len(node_data.data) == 0
959
960 # common case
961 df = pd.DataFrame(
962 {
963 "node_id": np.arange(num_nodes),
964 "label": np.random.randint(3, size=num_nodes),
965 }
966 )
967 csv_path = os.path.join(test_dir, "nodes.csv")
968 df.to_csv(csv_path, index=False)
969 meta_node = MetaNode(file_name=csv_path)
970 node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
971 assert np.array_equal(df["node_id"], node_data.id)
972 assert len(node_data.data) == 1
973 assert np.array_equal(df["label"], node_data.data["label"])
974 assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id)
975 assert node_data.type == "_V"
976
977 # add more fields into nodes.csv
978 df = pd.DataFrame(
979 {
980 "node_id": np.arange(num_nodes),
981 "label": np.random.randint(3, size=num_nodes),
982 "graph_id": np.full(num_nodes, 1),
983 }
984 )
985 csv_path = os.path.join(test_dir, "nodes.csv")
986 df.to_csv(csv_path, index=False)
987 meta_node = MetaNode(file_name=csv_path)
988 node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
989 assert np.array_equal(df["node_id"], node_data.id)
990 assert len(node_data.data) == 1
991 assert np.array_equal(df["label"], node_data.data["label"])
992 assert np.array_equal(df["graph_id"], node_data.graph_id)
993 assert node_data.type == "_V"
994
995 # required header is missing
996 df = pd.DataFrame({"label": np.random.randint(3, size=num_nodes)})
997 csv_path = os.path.join(test_dir, "nodes.csv")
998 df.to_csv(csv_path, index=False)
999 meta_node = MetaNode(file_name=csv_path)
1000 expect_except = False
1001 try:
1002 NodeData.load_from_csv(meta_node, DefaultDataParser())
1003 except:

Callers 1

test_csvdataset (Function) · 0.85

Calls 4

MetaNode (Class) · 0.90
DefaultDataParser (Class) · 0.90
join (Method) · 0.45
load_from_csv (Method) · 0.45

Tested by

no test coverage detected