(g, idtype, n_classes)
| 1821 | ) |
| 1822 | @pytest.mark.parametrize("n_classes", [2]) |
def test_pgexplainer(g, idtype, n_classes):
    """Smoke-test ``nn.PGExplainer`` in graph mode and in node mode.

    A small GraphConv-based model is wrapped by the explainer, trained for a
    step, and then explained. Node mode is exercised with every node-ID form
    the explainer accepts here: a Python int, a Python list, a 0-d tensor and
    a 1-d tensor. The returned class probabilities must have ``n_classes``
    columns, and the edge-weight vector must have one entry per edge of the
    graph that was explained.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    g.ndata["attr"] = feat

    # add reverse edges
    # NOTE(review): presumably PGExplainer symmetrizes its edge mask by looking
    # up the reverse of every edge, so each edge needs a reverse counterpart;
    # confirm against the explainer implementation.
    transform = dgl.transforms.AddReverse(copy_edata=True)
    g = transform(g)
    # AddReverse only adds edges, so ``feat`` still equals g.ndata["attr"];
    # use the one name consistently for training and for explaining.

    class Model(th.nn.Module):
        """One GraphConv layer with an optional mean-pool + linear readout."""

        def __init__(self, in_feats, out_feats, graph=False):
            super().__init__()
            self.graph = graph
            self.conv = nn.GraphConv(in_feats, out_feats)
            self.fc = th.nn.Linear(out_feats, out_feats)
            th.nn.init.xavier_uniform_(self.fc.weight)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv(g, h, edge_weight=edge_weight)

            # Node-level model, or caller asked for node embeddings only.
            if not self.graph or embed:
                return h

            # Graph-level readout; local_scope keeps "h" off the caller's graph.
            with g.local_scope():
                g.ndata["h"] = h
                hg = dgl.mean_nodes(g, "h")
                return self.fc(hg)

    # graph explainer
    model = Model(feat.shape[1], n_classes, graph=True).to(ctx)
    explainer = nn.PGExplainer(model, n_classes)
    explainer.train_step(g, feat, 5.0)

    probs, edge_weight = explainer.explain_graph(g, feat)
    assert probs.shape[-1] == n_classes
    assert edge_weight.shape[0] == g.num_edges()

    # node explainer
    model = Model(feat.shape[1], n_classes, graph=False).to(ctx)
    explainer = nn.PGExplainer(
        model, n_classes, num_hops=1, explain_graph=False
    )
    # Every supported node-ID form, in the same order as before.
    node_id_forms = [0, [0, 1], th.tensor(0), th.tensor([0, 1])]
    for nodes in node_id_forms:
        explainer.train_step_node(nodes, g, feat, 5.0)

    for nodes in node_id_forms:
        probs, edge_weight, bg, inverse_indices = explainer.explain_node(
            nodes, g, feat
        )
        assert probs.shape[-1] == n_classes
        # One weight per edge of the batched k-hop subgraph that was explained.
        assert edge_weight.shape[0] == bg.num_edges()
nothing calls this directly
no test coverage detected