MCPcopy
hub / github.com/dmlc/dgl / test_local_sampling_heterograph

Function test_local_sampling_heterograph

tests/distributed/test_distributed_sampling.py:1766–1858  ·  view source on GitHub ↗
(num_parts, use_graphbolt, prob_or_mask)

Source from the content-addressed store, hash-verified

1764@pytest.mark.parametrize("use_graphbolt", [False])
1765@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"])
1766def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
1767 reset_envs()
1768 os.environ["DGL_DIST_MODE"] = "distributed"
1769 with tempfile.TemporaryDirectory() as test_dir:
1770 g = create_random_hetero()
1771 for c_etype in g.canonical_etypes:
1772 prob = torch.rand(g.num_edges(c_etype))
1773 mask = prob > 0.2
1774 prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0
1775 g.edges[c_etype].data["prob"] = prob
1776 g.edges[c_etype].data["mask"] = mask
1777 graph_name = "test_local_sampling"
1778
1779 _, orig_eids = partition_graph(
1780 g,
1781 graph_name,
1782 num_parts,
1783 test_dir,
1784 num_hops=1,
1785 part_method="metis",
1786 return_mapping=True,
1787 use_graphbolt=use_graphbolt,
1788 store_eids=True,
1789 store_inner_node=True,
1790 store_inner_edge=True,
1791 )
1792
1793 part_config = os.path.join(test_dir, f"{graph_name}.json")
1794 for part_id in range(num_parts):
1795 local_g, _, edge_feats, gpb, _, _, _ = load_partition(
1796 part_config,
1797 part_id,
1798 load_feats=True,
1799 use_graphbolt=use_graphbolt,
1800 )
1801 inner_global_nids = [
1802 gpb.map_to_homo_nid(gpb.partid2nids(part_id, ntype), ntype)
1803 for ntype in gpb.ntypes
1804 ]
1805 inner_global_nids = torch.cat(inner_global_nids)
1806 inner_global_eids = {
1807 c_etype: gpb.partid2eids(part_id, c_etype)
1808 for c_etype in gpb.canonical_etypes
1809 }
1810 inner_node_data = (
1811 local_g.node_attributes["inner_node"]
1812 if use_graphbolt
1813 else local_g.ndata["inner_node"]
1814 )
1815 inner_edge_data = (
1816 local_g.edge_attributes["inner_edge"]
1817 if use_graphbolt
1818 else local_g.edata["inner_edge"]
1819 )
1820 assert len(inner_global_nids) == inner_node_data.sum()
1821 num_inner_global_eids = sum(
1822 [len(eids) for eids in inner_global_eids.values()]
1823 )

Callers

nothing calls this directly

Calls 15

reset_envs — Function · 0.90
partition_graph — Function · 0.90
load_partition — Function · 0.90
_etype_tuple_to_str — Function · 0.90
append — Method · 0.80
nonzero — Method · 0.80
create_random_hetero — Function · 0.70
sum — Function · 0.50
num_edges — Method · 0.45
join — Method · 0.45
map_to_homo_nid — Method · 0.45
partid2nids — Method · 0.45

Tested by

no test coverage detected