MCPcopy
hub / github.com/dmlc/dgl / test_get_node_partition_from_book

Function test_get_node_partition_from_book

tests/python/common/test_partition.py:15–55  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

13)
@parametrize_idtype
def test_get_node_partition_from_book(idtype):
    """Check the node partition derived from a three-way RangePartitionBook.

    A book with one node type ("_N") and one edge type is converted via
    ``gpb.get_node_partition_from_book``. The result is checked for its
    partition count, array size, global->local and local->global ID
    mapping, and the output of ``generate_permutation``.
    """

    def on_ctx(values):
        # Build an ``idtype`` tensor from ``values`` on the test context.
        return F.copy_to(F.tensor(values, dtype=idtype), F.ctx())

    # One [lower, upper] row per partition. The expectations below treat
    # each row's upper bound as belonging to that partition (e.g. ID 10 is
    # local ID 4 of partition 2, and the total array size is 11).
    node_map = {"_N": F.tensor([[0, 3], [4, 5], [6, 10]], dtype=idtype)}
    edge_map = {
        ("_N", "_E", "_N"): F.tensor([[0, 9], [10, 15], [16, 25]], dtype=idtype)
    }
    ntypes = {name: idx for idx, name in enumerate(node_map)}
    etypes = {name: idx for idx, name in enumerate(edge_map)}

    book = gpb.RangePartitionBook(0, 3, node_map, edge_map, ntypes, etypes)
    partition = gpb.get_node_partition_from_book(book, F.ctx())
    assert partition.num_parts() == 3
    assert partition.array_size() == 11

    # Global IDs -> IDs local to whichever partition owns them.
    local = partition.map_to_local(on_ctx([0, 2, 6, 7, 10]))
    assert F.array_equal(local, on_ctx([0, 2, 0, 1, 4]))

    # Local IDs of a given partition -> global IDs.
    global_cases = (
        (0, [0, 2], [0, 2]),
        (1, [0, 1], [4, 5]),
        (2, [0, 1, 4], [6, 7, 10]),
    )
    for part_id, local_ids, global_ids in global_cases:
        mapped = partition.map_to_global(on_ctx(local_ids), part_id)
        assert F.array_equal(mapped, on_ctx(global_ids))

    # Permutation that groups the IDs by owning partition, plus the number
    # of IDs that landed in each partition.
    perm, split_sum = partition.generate_permutation(on_ctx([6, 0, 7, 2, 10]))
    assert F.array_equal(perm, on_ctx([1, 3, 0, 2, 4]))
    # The per-partition counts are compared against a tensor built with the
    # backend's default dtype rather than ``idtype``.
    assert F.array_equal(split_sum, F.copy_to(F.tensor([2, 0, 3]), F.ctx()))

Callers

nothing calls this directly

Calls 7

num_parts — Method · 0.80
array_size — Method · 0.80
map_to_local — Method · 0.80
map_to_global — Method · 0.80
generate_permutation — Method · 0.80
ctx — Method · 0.45
copy_to — Method · 0.45

Tested by

no test coverage detected