(empty_mask)
| 1251 | @unittest.skipIf(os.name == "nt", reason="Do not support windows yet") |
| 1252 | @pytest.mark.parametrize("empty_mask", [True, False]) |
| 1253 | def test_split_even(empty_mask): |
| 1254 | g = create_random_graph(10000) |
| 1255 | num_parts = 4 |
| 1256 | num_hops = 2 |
| 1257 | partition_graph( |
| 1258 | g, |
| 1259 | "dist_graph_test", |
| 1260 | num_parts, |
| 1261 | "/tmp/dist_graph", |
| 1262 | num_hops=num_hops, |
| 1263 | part_method="metis", |
| 1264 | ) |
| 1265 | |
| 1266 | mask_thd = 100 if empty_mask else 30 |
| 1267 | node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd |
| 1268 | edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd |
| 1269 | all_nodes1 = [] |
| 1270 | all_nodes2 = [] |
| 1271 | all_edges1 = [] |
| 1272 | all_edges2 = [] |
| 1273 | |
# The distributed code collects the roles of all client processes and uses
# that information to decide how to split the workload. This helper fakes
# that bookkeeping so the test can simulate the multi-client use case.
def set_roles(num_clients):
    """Overwrite the role state in ``dgl.distributed.role`` for ``num_clients`` clients.

    Sets the current role to "default" and replaces both the global rank
    table and the "default" per-role rank table with a mapping that sends
    every id in ``range(num_clients)`` to itself.
    """
    role_state = dgl.distributed.role
    client_ids = range(num_clients)

    role_state.CUR_ROLE = "default"
    # Build two separate dicts so the tables never alias each other.
    role_state.GLOBAL_RANK = dict(zip(client_ids, client_ids))
    role_state.PER_ROLE_RANK["default"] = dict(zip(client_ids, client_ids))
| 1283 | |
| 1284 | for i in range(num_parts): |
| 1285 | set_roles(num_parts) |
| 1286 | part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( |
| 1287 | "/tmp/dist_graph/dist_graph_test.json", i |
| 1288 | ) |
| 1289 | local_nids = F.nonzero_1d(part_g.ndata["inner_node"]) |
| 1290 | local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids) |
| 1291 | nodes = node_split(node_mask, gpb, rank=i, force_even=True) |
| 1292 | all_nodes1.append(nodes) |
| 1293 | subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids)) |
| 1294 | print( |
| 1295 | "part {} get {} nodes and {} are in the partition".format( |
| 1296 | i, len(nodes), len(subset) |
| 1297 | ) |
| 1298 | ) |
| 1299 | |
| 1300 | set_roles(num_parts * 2) |
| 1301 | nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True) |
| 1302 | nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True) |
| 1303 | nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0)) |
| 1304 | all_nodes2.append(nodes3) |
| 1305 | subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3)) |
| 1306 | print("intersection has", len(subset)) |
| 1307 | |
| 1308 | set_roles(num_parts) |
| 1309 | local_eids = F.nonzero_1d(part_g.edata["inner_edge"]) |
| 1310 | local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids) |
no test coverage detected