MCPcopy
hub / github.com/dmlc/dgl / test_pgexplainer

Function test_pgexplainer

tests/python/pytorch/nn/test_nn.py:1823–1880  ·  view source on GitHub ↗
(g, idtype, n_classes)

Source from the content-addressed store, hash-verified

1821)
1822@pytest.mark.parametrize("n_classes", [2])
def test_pgexplainer(g, idtype, n_classes):
    """Smoke-test ``nn.PGExplainer`` in graph-level and node-level modes.

    The test makes no value assertions: it passes when every training and
    explanation call completes without raising and each result unpacks into
    the expected number of items (two for ``explain_graph``, four for
    ``explain_node``).

    Parameters
    ----------
    g
        Input graph supplied by the test parametrization; it is cast to
        ``idtype`` and moved to the backend context before use.
    idtype
        Index dtype the graph is converted to.
    n_classes
        Output size of the toy model, also passed to ``PGExplainer``.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat = F.randn((g.num_nodes(), 5))
    g.ndata["attr"] = feat

    # add reverse edges
    g = dgl.transforms.AddReverse(copy_edata=True)(g)

    class Model(th.nn.Module):
        """One ``GraphConv`` layer with an optional mean-pool + linear head."""

        def __init__(self, in_feats, out_feats, graph=False):
            super().__init__()
            self.graph = graph
            self.conv = nn.GraphConv(in_feats, out_feats)
            self.fc = th.nn.Linear(out_feats, out_feats)
            th.nn.init.xavier_uniform_(self.fc.weight)

        def forward(self, g, h, embed=False, edge_weight=None):
            h = self.conv(g, h, edge_weight=edge_weight)

            # Only the graph-level model, when not asked for embeddings,
            # pools the node outputs and applies the linear head.
            if self.graph and not embed:
                with g.local_scope():
                    g.ndata["h"] = h
                    pooled = dgl.mean_nodes(g, "h")
                    return self.fc(pooled)

            return h

    in_feats = feat.shape[1]
    temperature = 5.0

    def node_queries():
        # Built anew for every pass so each call receives its own objects,
        # exactly as when the arguments are written inline at the call site.
        return (0, [0, 1], th.tensor(0), th.tensor([0, 1]))

    # graph explainer
    graph_model = Model(in_feats, n_classes, graph=True).to(ctx)
    graph_explainer = nn.PGExplainer(graph_model, n_classes)
    graph_explainer.train_step(g, g.ndata["attr"], temperature)

    probs, edge_weight = graph_explainer.explain_graph(g, feat)

    # node explainer
    node_model = Model(in_feats, n_classes, graph=False).to(ctx)
    node_explainer = nn.PGExplainer(
        node_model, n_classes, num_hops=1, explain_graph=False
    )
    for nodes in node_queries():
        node_explainer.train_step_node(
            nodes, g, g.ndata["attr"], temperature
        )

    for nodes in node_queries():
        probs, edge_weight, bg, inverse_indices = node_explainer.explain_node(
            nodes, g, feat
        )

Callers

nothing calls this directly

Calls 10

train_step · Method · 0.95
explain_graph · Method · 0.95
train_step_node · Method · 0.95
explain_node · Method · 0.95
transform · Function · 0.85
Model · Class · 0.70
ctx · Method · 0.45
to · Method · 0.45
astype · Method · 0.45
num_nodes · Method · 0.45

Tested by

no test coverage detected