MCPcopy
hub / github.com/dmlc/dgl / test_split

Function test_split

tests/distributed/test_dist_graph_store.py:1158–1248  ·  view source on GitHub ↗
(hetero, empty_mask)

Source from the content-addressed store, hash-verified

1156@pytest.mark.parametrize("hetero", [True, False])
1157@pytest.mark.parametrize("empty_mask", [True, False])
1158def test_split(hetero, empty_mask):
1159 if hetero:
1160 g = create_random_hetero()
1161 ntype = "n1"
1162 etype = "r1"
1163 else:
1164 g = create_random_graph(10000)
1165 ntype = "_N"
1166 etype = "_E"
1167 num_parts = 4
1168 num_hops = 2
1169 partition_graph(
1170 g,
1171 "dist_graph_test",
1172 num_parts,
1173 "/tmp/dist_graph",
1174 num_hops=num_hops,
1175 part_method="metis",
1176 )
1177
1178 mask_thd = 100 if empty_mask else 30
1179 node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd
1180 edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd
1181 selected_nodes = np.nonzero(node_mask)[0]
1182 selected_edges = np.nonzero(edge_mask)[0]
1183
1184 # The code now collects the roles of all client processes and uses that information
1185 # to determine how to split the workloads. Here we simulate the multi-client
1186 # use case.
1187 def set_roles(num_clients):
1188 dgl.distributed.role.CUR_ROLE = "default"
1189 dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}
1190 dgl.distributed.role.PER_ROLE_RANK["default"] = {
1191 i: i for i in range(num_clients)
1192 }
1193
1194 for i in range(num_parts):
1195 set_roles(num_parts)
1196 part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
1197 "/tmp/dist_graph/dist_graph_test.json", i
1198 )
1199 local_nids = F.nonzero_1d(part_g.ndata["inner_node"])
1200 local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
1201 if hetero:
1202 ntype_ids, nids = gpb.map_to_per_ntype(local_nids)
1203 local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]
1204 else:
1205 local_nids = F.asnumpy(local_nids)
1206 nodes1 = np.intersect1d(selected_nodes, local_nids)
1207 nodes2 = node_split(
1208 node_mask, gpb, ntype=ntype, rank=i, force_even=False
1209 )
1210 assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
1211 for n in F.asnumpy(nodes2):
1212 assert n in local_nids
1213
1214 set_roles(num_parts * 2)
1215 nodes3 = node_split(

Callers 1

Calls 13

create_random_graphFunction · 0.90
partition_graphFunction · 0.90
load_partitionFunction · 0.90
node_splitFunction · 0.90
edge_splitFunction · 0.90
set_rolesFunction · 0.85
nonzeroMethod · 0.80
asnumpyMethod · 0.80
create_random_heteroFunction · 0.70
num_nodesMethod · 0.45
num_edgesMethod · 0.45
map_to_per_ntypeMethod · 0.45

Tested by

no test coverage detected