MCPcopy
hub / github.com/dmlc/dgl / _test_construct_graphs_hetero

Function _test_construct_graphs_hetero

tests/python/common/data/test_data.py:570–643  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

568
569
570def _test_construct_graphs_hetero():
571 from dgl.data.csv_dataset_base import (
572 DGLGraphConstructor,
573 EdgeData,
574 NodeData,
575 )
576
577 # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
578 num_nodes = 100
579 num_edges = 1000
580 num_dims = 3
581 ntypes = ["user", "item"]
582 node_data = []
583 node_ids_dict = {}
584 ndata_dict = {}
585 for ntype in ntypes:
586 node_ids = np.random.choice(
587 np.arange(num_nodes * 2), size=num_nodes, replace=False
588 )
589 assert len(node_ids) == num_nodes
590 # to be non-sorted
591 np.random.shuffle(node_ids)
592 # to be non-numeric
593 node_ids = ["id_{}".format(id) for id in node_ids]
594 t_ndata = {
595 "feat": np.random.rand(num_nodes, num_dims),
596 "label": np.random.randint(2, size=num_nodes),
597 }
598 _, u_indices = np.unique(node_ids, return_index=True)
599 ndata = {
600 "feat": t_ndata["feat"][u_indices],
601 "label": t_ndata["label"][u_indices],
602 }
603 node_data.append(NodeData(node_ids, t_ndata, type=ntype))
604 node_ids_dict[ntype] = node_ids
605 ndata_dict[ntype] = ndata
606 etypes = [("user", "follow", "user"), ("user", "like", "item")]
607 edge_data = []
608 edata_dict = {}
609 for src_type, e_type, dst_type in etypes:
610 src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges)
611 dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
612 edata = {
613 "feat": np.random.rand(num_edges, num_dims),
614 "label": np.random.randint(2, size=num_edges),
615 }
616 edge_data.append(
617 EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type))
618 )
619 edata_dict[(src_type, e_type, dst_type)] = edata
620 graphs, data_dict = DGLGraphConstructor.construct_graphs(
621 node_data, edge_data
622 )
623 assert len(graphs) == 1
624 assert len(data_dict) == 0
625 g = graphs[0]
626 assert not g.is_homogeneous
627 assert g.num_nodes() == num_nodes * len(ntypes)

Callers 1

test_csvdataset · Function · 0.85

Calls 9

NodeData · Class · 0.90
EdgeData · Class · 0.90
assert_data · Function · 0.85
format · Method · 0.80
append · Method · 0.80
construct_graphs · Method · 0.80
shuffle · Method · 0.45
num_nodes · Method · 0.45
num_edges · Method · 0.45

Tested by

no test coverage detected