MCPcopy
hub / github.com/dmlc/dgl / test_split_even

Function test_split_even

tests/distributed/test_dist_graph_store.py:1253–1336  ·  view source on GitHub ↗
(empty_mask)

Source from the content-addressed store, hash-verified

1251@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
1252@pytest.mark.parametrize("empty_mask", [True, False])
1253def test_split_even(empty_mask):
1254 g = create_random_graph(10000)
1255 num_parts = 4
1256 num_hops = 2
1257 partition_graph(
1258 g,
1259 "dist_graph_test",
1260 num_parts,
1261 "/tmp/dist_graph",
1262 num_hops=num_hops,
1263 part_method="metis",
1264 )
1265
1266 mask_thd = 100 if empty_mask else 30
1267 node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd
1268 edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd
1269 all_nodes1 = []
1270 all_nodes2 = []
1271 all_edges1 = []
1272 all_edges2 = []
1273
1274 # The code now collects the roles of all client processes and use the information
1275 # to determine how to split the workloads. Here is to simulate the multi-client
1276 # use case.
1277 def set_roles(num_clients):
1278 dgl.distributed.role.CUR_ROLE = "default"
1279 dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
1280 dgl.distributed.role.PER_ROLE_RANK["default"] = {
1281 i: i for i in range(num_clients)
1282 }
1283
1284 for i in range(num_parts):
1285 set_roles(num_parts)
1286 part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
1287 "/tmp/dist_graph/dist_graph_test.json", i
1288 )
1289 local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
1290 local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
1291 nodes = node_split(node_mask, gpb, rank=i, force_even=True)
1292 all_nodes1.append(nodes)
1293 subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))
1294 print(
1295 "part {} get {} nodes and {} are in the partition".format(
1296 i, len(nodes), len(subset)
1297 )
1298 )
1299
1300 set_roles(num_parts * 2)
1301 nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)
1302 nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)
1303 nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))
1304 all_nodes2.append(nodes3)
1305 subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))
1306 print("intersection has", len(subset))
1307
1308 set_roles(num_parts)
1309 local_eids = F.nonzero_1d(part_g.edata["inner_edge"])
1310 local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)

Callers 1

Calls 12

create_random_graph · Function · 0.90
partition_graph · Function · 0.90
load_partition · Function · 0.90
node_split · Function · 0.90
edge_split · Function · 0.90
set_roles · Function · 0.85
append · Method · 0.80
asnumpy · Method · 0.80
format · Method · 0.80
nonzero · Method · 0.80
num_nodes · Method · 0.45
num_edges · Method · 0.45

Tested by

no test coverage detected