(num_parts, use_graphbolt, prob_or_mask)
| 1764 | @pytest.mark.parametrize("use_graphbolt", [False]) |
| 1765 | @pytest.mark.parametrize("prob_or_mask", ["prob", "mask"]) |
| 1766 | def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask): |
| 1767 | reset_envs() |
| 1768 | os.environ["DGL_DIST_MODE"] = "distributed" |
| 1769 | with tempfile.TemporaryDirectory() as test_dir: |
| 1770 | g = create_random_hetero() |
| 1771 | for c_etype in g.canonical_etypes: |
| 1772 | prob = torch.rand(g.num_edges(c_etype)) |
| 1773 | mask = prob > 0.2 |
| 1774 | prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0 |
| 1775 | g.edges[c_etype].data["prob"] = prob |
| 1776 | g.edges[c_etype].data["mask"] = mask |
| 1777 | graph_name = "test_local_sampling" |
| 1778 | |
| 1779 | _, orig_eids = partition_graph( |
| 1780 | g, |
| 1781 | graph_name, |
| 1782 | num_parts, |
| 1783 | test_dir, |
| 1784 | num_hops=1, |
| 1785 | part_method="metis", |
| 1786 | return_mapping=True, |
| 1787 | use_graphbolt=use_graphbolt, |
| 1788 | store_eids=True, |
| 1789 | store_inner_node=True, |
| 1790 | store_inner_edge=True, |
| 1791 | ) |
| 1792 | |
| 1793 | part_config = os.path.join(test_dir, f"{graph_name}.json") |
| 1794 | for part_id in range(num_parts): |
| 1795 | local_g, _, edge_feats, gpb, _, _, _ = load_partition( |
| 1796 | part_config, |
| 1797 | part_id, |
| 1798 | load_feats=True, |
| 1799 | use_graphbolt=use_graphbolt, |
| 1800 | ) |
| 1801 | inner_global_nids = [ |
| 1802 | gpb.map_to_homo_nid(gpb.partid2nids(part_id, ntype), ntype) |
| 1803 | for ntype in gpb.ntypes |
| 1804 | ] |
| 1805 | inner_global_nids = torch.cat(inner_global_nids) |
| 1806 | inner_global_eids = { |
| 1807 | c_etype: gpb.partid2eids(part_id, c_etype) |
| 1808 | for c_etype in gpb.canonical_etypes |
| 1809 | } |
| 1810 | inner_node_data = ( |
| 1811 | local_g.node_attributes["inner_node"] |
| 1812 | if use_graphbolt |
| 1813 | else local_g.ndata["inner_node"] |
| 1814 | ) |
| 1815 | inner_edge_data = ( |
| 1816 | local_g.edge_attributes["inner_edge"] |
| 1817 | if use_graphbolt |
| 1818 | else local_g.edata["inner_edge"] |
| 1819 | ) |
| 1820 | assert len(inner_global_nids) == inner_node_data.sum() |
| 1821 | num_inner_global_eids = sum( |
| 1822 | [len(eids) for eids in inner_global_eids.values()] |
| 1823 | ) |
nothing calls this directly
no test coverage detected