MCPcopy Index your code
hub / github.com/dmlc/dgl / _test_construct_graphs_homo

Function _test_construct_graphs_homo

tests/python/common/data/test_data.py:512–567  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

510
511
def _test_construct_graphs_homo():
    """Check ``DGLGraphConstructor.construct_graphs`` on one homogeneous graph.

    Builds random node data keyed by non-sorted, non-numeric string IDs and
    random edges between those IDs. It then constructs the graph(s) and
    verifies:

    * a single homogeneous graph is returned with no extra graph-level data;
    * node and edge counts match;
    * node/edge features and labels round-trip into ``g.ndata`` / ``g.edata``.
    """
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id could be non-sorted, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    # Draw distinct integer IDs from a range twice as large as num_nodes so
    # the IDs are not simply 0..num_nodes-1.
    node_ids = np.random.choice(
        np.arange(num_nodes * 2), size=num_nodes, replace=False
    )
    assert len(node_ids) == num_nodes
    # to be non-sorted
    np.random.shuffle(node_ids)
    # to be non-numeric (``nid`` avoids shadowing the ``id`` builtin)
    node_ids = ["id_{}".format(nid) for nid in node_ids]
    t_ndata = {
        "feat": np.random.rand(num_nodes, num_dims),
        "label": np.random.randint(2, size=num_nodes),
    }
    # Expected node data: rows reordered into np.unique order, i.e. sorted
    # by string ID. The assertions at the end expect the constructor to
    # number graph nodes in that same order.
    _, u_indices = np.unique(node_ids, return_index=True)
    ndata = {
        "feat": t_ndata["feat"][u_indices],
        "label": t_ndata["label"][u_indices],
    }
    node_data = NodeData(node_ids, t_ndata)
    # Endpoints are sampled independently and with replacement, so duplicate
    # edges and self-loops may occur; the graph must still keep all of them.
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    edata = {
        "feat": np.random.rand(num_edges, num_dims),
        "label": np.random.randint(2, size=num_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    # Exactly one graph is expected. No graph-level data was supplied, so
    # data_dict should be empty.
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    def assert_data(lhs, rhs):
        """Assert every array in ``lhs`` is stored in ``rhs`` with equal values.

        Each expected value is cast to the dtype of the stored tensor before
        comparison. Only keys present in ``lhs`` are checked.
        """
        for key, value in lhs.items():
            assert key in rhs
            # Stored tensors are expected not to remain float64 (presumably
            # the constructor downcasts float features -- TODO confirm).
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    # Node data is compared in sorted-ID order; edge data in input order.
    assert_data(ndata, g.ndata)
    assert_data(edata, g.edata)
568
569

Callers 1

test_csvdatasetFunction · 0.85

Calls 8

NodeDataClass · 0.90
EdgeDataClass · 0.90
assert_dataFunction · 0.85
formatMethod · 0.80
construct_graphsMethod · 0.80
shuffleMethod · 0.45
num_nodesMethod · 0.45
num_edgesMethod · 0.45

Tested by

no test coverage detected