(g, extra_hops)
| 767 | |
| 768 | |
| 769 | def check_metis_partition(g, extra_hops): |
| 770 | subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=extra_hops) |
| 771 | num_inner_nodes = 0 |
| 772 | num_inner_edges = 0 |
| 773 | if subgs is not None: |
| 774 | for part_id, subg in subgs.items(): |
| 775 | lnode_ids = np.nonzero(F.asnumpy(subg.ndata["inner_node"]))[0] |
| 776 | ledge_ids = np.nonzero(F.asnumpy(subg.edata["inner_edge"]))[0] |
| 777 | num_inner_nodes += len(lnode_ids) |
| 778 | num_inner_edges += len(ledge_ids) |
| 779 | assert np.sum(F.asnumpy(subg.ndata["part_id"]) == part_id) == len( |
| 780 | lnode_ids |
| 781 | ) |
| 782 | assert num_inner_nodes == g.num_nodes() |
| 783 | print(g.num_edges() - num_inner_edges) |
| 784 | |
| 785 | if extra_hops == 0: |
| 786 | return |
| 787 | |
| 788 | # partitions with node reshuffling |
| 789 | subgs = dgl.transforms.metis_partition( |
| 790 | g, 4, extra_cached_hops=extra_hops, reshuffle=True |
| 791 | ) |
| 792 | num_inner_nodes = 0 |
| 793 | num_inner_edges = 0 |
| 794 | edge_cnts = np.zeros((g.num_edges(),)) |
| 795 | if subgs is not None: |
| 796 | for part_id, subg in subgs.items(): |
| 797 | lnode_ids = np.nonzero(F.asnumpy(subg.ndata["inner_node"]))[0] |
| 798 | ledge_ids = np.nonzero(F.asnumpy(subg.edata["inner_edge"]))[0] |
| 799 | num_inner_nodes += len(lnode_ids) |
| 800 | num_inner_edges += len(ledge_ids) |
| 801 | assert np.sum(F.asnumpy(subg.ndata["part_id"]) == part_id) == len( |
| 802 | lnode_ids |
| 803 | ) |
| 804 | nids = F.asnumpy(subg.ndata[dgl.NID]) |
| 805 | |
| 806 | # ensure the local node Ids are contiguous. |
| 807 | parent_ids = F.asnumpy(subg.ndata[dgl.NID]) |
| 808 | parent_ids = parent_ids[: len(lnode_ids)] |
| 809 | assert np.all( |
| 810 | parent_ids == np.arange(parent_ids[0], parent_ids[-1] + 1) |
| 811 | ) |
| 812 | |
| 813 | # count the local edges. |
| 814 | parent_ids = F.asnumpy(subg.edata[dgl.EID])[ledge_ids] |
| 815 | edge_cnts[parent_ids] += 1 |
| 816 | |
| 817 | orig_ids = subg.ndata["orig_id"] |
| 818 | inner_node = F.asnumpy(subg.ndata["inner_node"]) |
| 819 | for nid in range(subg.num_nodes()): |
| 820 | neighs = subg.predecessors(nid) |
| 821 | old_neighs1 = F.gather_row(orig_ids, neighs) |
| 822 | old_nid = F.asnumpy(orig_ids[nid]) |
| 823 | old_neighs2 = g.predecessors(old_nid) |
| 824 | # If this is an inner node, it should have the full neighborhood. |
| 825 | if inner_node[nid]: |
| 826 | assert np.all( |
no test coverage detected