MCPcopy
hub / github.com/dmlc/dgl / test_to_block

Function test_to_block

tests/python/common/transforms/test_to_block.py:25–192  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

23
24@parametrize_idtype
25def test_to_block(idtype):
26 def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):
27 if dst_nodes is not None:
28 assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes)
29 n_dst_nodes = bg.num_nodes("DST/" + ntype)
30 if include_dst_in_src:
31 assert F.array_equal(
32 bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes],
33 bg.dstnodes[ntype].data[dgl.NID],
34 )
35
36 g = g[etype]
37 bg = bg[etype]
38 induced_src = bg.srcdata[dgl.NID]
39 induced_dst = bg.dstdata[dgl.NID]
40 induced_eid = bg.edata[dgl.EID]
41
42 bg_src, bg_dst = bg.all_edges(order="eid")
43 src_ans, dst_ans = g.all_edges(order="eid")
44
45 induced_src_bg = F.gather_row(induced_src, bg_src)
46 induced_dst_bg = F.gather_row(induced_dst, bg_dst)
47 induced_src_ans = F.gather_row(src_ans, induced_eid)
48 induced_dst_ans = F.gather_row(dst_ans, induced_eid)
49
50 assert F.array_equal(induced_src_bg, induced_src_ans)
51 assert F.array_equal(induced_dst_bg, induced_dst_ans)
52
53 def checkall(g, bg, dst_nodes, include_dst_in_src=True):
54 for etype in g.etypes:
55 ntype = g.to_canonical_etype(etype)[2]
56 if dst_nodes is not None and ntype in dst_nodes:
57 check(g, bg, ntype, etype, dst_nodes[ntype], include_dst_in_src)
58 else:
59 check(g, bg, ntype, etype, None, include_dst_in_src)
60
61 # homogeneous graph
62 g = dgl.graph(
63 (F.tensor([1, 2], dtype=idtype), F.tensor([2, 3], dtype=idtype))
64 )
65 dst_nodes = F.tensor([3, 2], dtype=idtype)
66 bg = dgl.to_block(g, dst_nodes=dst_nodes)
67 check(g, bg, "_N", "_E", dst_nodes)
68
69 src_nodes = bg.srcnodes["_N"].data[dgl.NID]
70 bg = dgl.to_block(g, dst_nodes=dst_nodes, src_nodes=src_nodes)
71 check(g, bg, "_N", "_E", dst_nodes)
72
73 # heterogeneous graph
74 g = dgl.heterograph(
75 {
76 ("A", "AA", "A"): ([0, 2, 1, 3], [1, 3, 2, 4]),
77 ("A", "AB", "B"): ([0, 1, 3, 1], [1, 3, 5, 6]),
78 ("B", "BA", "A"): ([2, 3], [3, 2]),
79 },
80 idtype=idtype,
81 device=F.ctx(),
82 )

Callers

nothing calls this directly

Calls 9

check_features — Function · 0.85
checkall — Function · 0.85
number_of_src_nodes — Method · 0.80
number_of_dst_nodes — Method · 0.80
check — Function · 0.70
graph — Method · 0.45
ctx — Method · 0.45
num_nodes — Method · 0.45
keys — Method · 0.45

Tested by

no test coverage detected