(idtype)
| 23 | |
| 24 | @parametrize_idtype |
| 25 | def test_to_block(idtype): |
def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):
    """Assert that block ``bg`` is consistent with graph ``g`` for one edge type.

    Three things are asserted:

    * when ``dst_nodes`` is given, the DST-side ``dgl.NID`` of ``ntype`` in
      ``bg`` equals ``dst_nodes``;
    * when ``include_dst_in_src`` is true, the first
      ``bg.num_nodes("DST/" + ntype)`` SRC-side ``dgl.NID`` entries equal the
      DST-side ``dgl.NID`` entries;
    * for the ``etype`` relation, translating each block edge's endpoints
      through the block's ``dgl.NID`` data gives the same node IDs as looking
      up the original edge (via the block's ``dgl.EID`` data) in ``g``.

    Args:
        g: the graph the block was built from.
        bg: the block under test.
        ntype: node type whose DST/SRC ID lists are compared.
        etype: edge type whose edges are compared.
        dst_nodes: expected DST-side node IDs for ``ntype``, or ``None`` to
            skip that comparison.
        include_dst_in_src: whether to assert the DST IDs prefix the SRC IDs.
    """
    if dst_nodes is not None:
        assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes)
    num_dst = bg.num_nodes("DST/" + ntype)
    if include_dst_in_src:
        src_prefix = bg.srcnodes[ntype].data[dgl.NID][:num_dst]
        assert F.array_equal(src_prefix, bg.dstnodes[ntype].data[dgl.NID])

    orig_rel = g[etype]
    block_rel = bg[etype]
    src_id_map = block_rel.srcdata[dgl.NID]
    dst_id_map = block_rel.dstdata[dgl.NID]
    edge_id_map = block_rel.edata[dgl.EID]

    block_u, block_v = block_rel.all_edges(order="eid")
    orig_u, orig_v = orig_rel.all_edges(order="eid")

    # Block-local edge endpoints translated into the parent graph's node IDs.
    mapped_u = F.gather_row(src_id_map, block_u)
    mapped_v = F.gather_row(dst_id_map, block_v)
    # Endpoints of the parent edges that each block edge maps back to.
    expected_u = F.gather_row(orig_u, edge_id_map)
    expected_v = F.gather_row(orig_v, edge_id_map)

    assert F.array_equal(mapped_u, expected_u)
    assert F.array_equal(mapped_v, expected_v)
| 52 | |
def checkall(g, bg, dst_nodes, include_dst_in_src=True):
    """Run ``check`` for every edge type of ``g``.

    For each edge type, the destination node type is taken from
    ``g.to_canonical_etype``. The expected DST node IDs passed to ``check``
    are ``dst_nodes[<that node type>]`` when ``dst_nodes`` is a mapping
    containing that key; otherwise ``None`` is passed and ``check`` skips the
    exact-DST comparison.
    """
    for etype in g.etypes:
        _, _, dst_ntype = g.to_canonical_etype(etype)
        expected_dst = None
        if dst_nodes is not None and dst_ntype in dst_nodes:
            expected_dst = dst_nodes[dst_ntype]
        check(g, bg, dst_ntype, etype, expected_dst, include_dst_in_src)
| 60 | |
| 61 | # homogeneous graph |
| 62 | g = dgl.graph( |
| 63 | (F.tensor([1, 2], dtype=idtype), F.tensor([2, 3], dtype=idtype)) |
| 64 | ) |
| 65 | dst_nodes = F.tensor([3, 2], dtype=idtype) |
| 66 | bg = dgl.to_block(g, dst_nodes=dst_nodes) |
| 67 | check(g, bg, "_N", "_E", dst_nodes) |
| 68 | |
| 69 | src_nodes = bg.srcnodes["_N"].data[dgl.NID] |
| 70 | bg = dgl.to_block(g, dst_nodes=dst_nodes, src_nodes=src_nodes) |
| 71 | check(g, bg, "_N", "_E", dst_nodes) |
| 72 | |
| 73 | # heterogeneous graph |
| 74 | g = dgl.heterograph( |
| 75 | { |
| 76 | ("A", "AA", "A"): ([0, 2, 1, 3], [1, 3, 2, 4]), |
| 77 | ("A", "AB", "B"): ([0, 1, 3, 1], [1, 3, 5, 6]), |
| 78 | ("B", "BA", "A"): ([2, 3], [3, 2]), |
| 79 | }, |
| 80 | idtype=idtype, |
| 81 | device=F.ctx(), |
| 82 | ) |
nothing calls this directly
no test coverage detected