(hetero, empty_mask)
| 1156 | @pytest.mark.parametrize("hetero", [True, False]) |
| 1157 | @pytest.mark.parametrize("empty_mask", [True, False]) |
| 1158 | def test_split(hetero, empty_mask): |
| 1159 | if hetero: |
| 1160 | g = create_random_hetero() |
| 1161 | ntype = "n1" |
| 1162 | etype = "r1" |
| 1163 | else: |
| 1164 | g = create_random_graph(10000) |
| 1165 | ntype = "_N" |
| 1166 | etype = "_E" |
| 1167 | num_parts = 4 |
| 1168 | num_hops = 2 |
| 1169 | partition_graph( |
| 1170 | g, |
| 1171 | "dist_graph_test", |
| 1172 | num_parts, |
| 1173 | "/tmp/dist_graph", |
| 1174 | num_hops=num_hops, |
| 1175 | part_method="metis", |
| 1176 | ) |
| 1177 | |
| 1178 | mask_thd = 100 if empty_mask else 30 |
| 1179 | node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd |
| 1180 | edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd |
| 1181 | selected_nodes = np.nonzero(node_mask)[0] |
| 1182 | selected_edges = np.nonzero(edge_mask)[0] |
| 1183 | |
| 1184 | # The code now collects the roles of all client processes and uses that information |
| 1185 | # to determine how to split the workloads. The helper below simulates the multi-client |
| 1186 | # use case. |
def set_roles(num_clients):
    """Install fake role state for ``num_clients`` clients under the "default" role.

    Overwrites ``CUR_ROLE`` and ``GLOBAL_RANK`` in ``dgl.distributed.role`` and
    sets the ``"default"`` entry of ``PER_ROLE_RANK``. Both rank tables map each
    id in ``range(num_clients)`` to itself.
    """
    role_state = dgl.distributed.role
    client_ids = range(num_clients)
    role_state.CUR_ROLE = "default"
    # Two separate dict objects on purpose: the tables must not alias each other.
    role_state.GLOBAL_RANK = dict(zip(client_ids, client_ids))
    role_state.PER_ROLE_RANK["default"] = dict(zip(client_ids, client_ids))
| 1193 | |
| 1194 | for i in range(num_parts): |
| 1195 | set_roles(num_parts) |
| 1196 | part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( |
| 1197 | "/tmp/dist_graph/dist_graph_test.json", i |
| 1198 | ) |
| 1199 | local_nids = F.nonzero_1d(part_g.ndata["inner_node"]) |
| 1200 | local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids) |
| 1201 | if hetero: |
| 1202 | ntype_ids, nids = gpb.map_to_per_ntype(local_nids) |
| 1203 | local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0] |
| 1204 | else: |
| 1205 | local_nids = F.asnumpy(local_nids) |
| 1206 | nodes1 = np.intersect1d(selected_nodes, local_nids) |
| 1207 | nodes2 = node_split( |
| 1208 | node_mask, gpb, ntype=ntype, rank=i, force_even=False |
| 1209 | ) |
| 1210 | assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2))) |
| 1211 | for n in F.asnumpy(nodes2): |
| 1212 | assert n in local_nids |
| 1213 | |
| 1214 | set_roles(num_parts * 2) |
| 1215 | nodes3 = node_split( |
no test coverage detected