github.com/dmlc/dgl

Method explain_node

python/dgl/nn/pytorch/explain/gnnexplainer.py:165–317

Learn and return a node feature mask and a subgraph that play a crucial role in explaining the GNN's prediction for node node_id.

(self, node_id, graph, feat, **kwargs)


```python
    def explain_node(self, node_id, graph, feat, **kwargs):
        r"""Learn and return a node feature mask and subgraph that play a
        crucial role to explain the prediction made by the GNN for node
        :attr:`node_id`.

        Parameters
        ----------
        node_id : int
            The node to explain.
        graph : DGLGraph
            A homogeneous graph.
        feat : Tensor
            The input feature of shape :math:`(N, D)`. :math:`N` is the
            number of nodes, and :math:`D` is the feature size.
        kwargs : dict
            Additional arguments passed to the GNN model. Tensors whose
            first dimension is the number of nodes or edges will be
            assumed to be node/edge features.

        Returns
        -------
        new_node_id : Tensor
            The new ID of the input center node.
        sg : DGLGraph
            The subgraph induced on the k-hop in-neighborhood of the input center node.
        feat_mask : Tensor
            Learned node feature importance mask of shape :math:`(D)`, where :math:`D` is the
            feature size. The values are within range :math:`(0, 1)`.
            The higher, the more important.
        edge_mask : Tensor
            Learned importance mask of the edges in the subgraph, which is a tensor
            of shape :math:`(E)`, where :math:`E` is the number of edges in the
            subgraph. The values are within range :math:`(0, 1)`.
            The higher, the more important.

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch
        >>> import torch.nn as nn
        >>> from dgl.data import CoraGraphDataset
        >>> from dgl.nn import GNNExplainer

        >>> # Load dataset
        >>> data = CoraGraphDataset()
        >>> g = data[0]
        >>> features = g.ndata['feat']
        >>> labels = g.ndata['label']
        >>> train_mask = g.ndata['train_mask']

        >>> # Define a model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_feats, out_feats):
        ...         super(Model, self).__init__()
        ...         self.linear = nn.Linear(in_feats, out_feats)
        ...
```
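The Returns section says `sg` is the subgraph induced on the k-hop in-neighborhood of the center node and `new_node_id` is the center's ID inside that subgraph; the Calls list names `khop_in_subgraph` for this step. The pure-Python sketch below illustrates the extraction on a plain edge list, together with the kwargs rule from the Parameters section. It is not DGL's implementation: `khop_in_subgraph` and `slice_kwargs` here are hypothetical stand-ins, `k` plays the role of whatever hop count the explainer is configured with, and relabeling kept nodes in ascending order of original ID is an assumption.

```python
from collections import deque


def khop_in_subgraph(num_nodes, src, dst, center, k):
    """Hypothetical sketch: subgraph induced on the k-hop in-neighborhood of `center`.

    Edges are parallel lists `src[i] -> dst[i]`. Returns the relabeled center ID,
    the kept original node IDs (sorted), the kept original edge IDs, and the
    kept edges rewritten with the new node IDs.
    """
    # Reverse adjacency: for each node, the sources of its incoming edges.
    in_nbrs = [[] for _ in range(num_nodes)]
    for u, v in zip(src, dst):
        in_nbrs[v].append(u)

    # Breadth-first search backwards along in-edges, stopping after k hops.
    hops = {center: 0}
    queue = deque([center])
    while queue:
        v = queue.popleft()
        if hops[v] == k:
            continue
        for u in in_nbrs[v]:
            if u not in hops:
                hops[u] = hops[v] + 1
                queue.append(u)

    # Assumed relabeling rule: kept nodes ordered by original ID.
    nodes = sorted(hops)
    new_id = {old: new for new, old in enumerate(nodes)}

    # "Induced": keep every edge whose two endpoints were both kept.
    eids = [e for e, (u, v) in enumerate(zip(src, dst)) if u in new_id and v in new_id]
    edges = [(new_id[src[e]], new_id[dst[e]]) for e in eids]
    return new_id[center], nodes, eids, edges


def slice_kwargs(kwargs, nodes, eids, num_nodes, num_edges):
    """Hypothetical sketch of the docstring's kwargs rule: a sequence whose length
    equals the node count is sliced by kept nodes; one whose length equals the
    edge count is sliced by kept edges; anything else passes through unchanged."""
    out = {}
    for key, item in kwargs.items():
        if hasattr(item, "__len__") and len(item) == num_nodes:
            out[key] = [item[i] for i in nodes]
        elif hasattr(item, "__len__") and len(item) == num_edges:
            out[key] = [item[e] for e in eids]
        else:
            out[key] = item
    return out
```

For example, `khop_in_subgraph(5, [0, 1, 3, 4, 2, 1], [1, 2, 2, 3, 4, 3], center=2, k=1)` keeps nodes `[1, 2, 3]`, so the center's new ID is 1, and keeps original edges `[1, 2, 5]`. Edge 1 → 3 survives even though it does not point at the center, because "induced" keeps every edge between kept nodes. When the node and edge counts happen to be equal, this sketch treats an ambiguous argument as a node feature, since it checks nodes first.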

Callers (2)

test_gnnexplainer (function) · 0.95
main (function) · 0.95

Calls (14)

_init_masks (method) · 0.95
_loss_regularize (method) · 0.95
khop_in_subgraph (function) · 0.85
to (method) · 0.45
num_nodes (method) · 0.45
num_edges (method) · 0.45
long (method) · 0.45
items (method) · 0.45
size (method) · 0.45
zero_grad (method) · 0.45
backward (method) · 0.45
step (method) · 0.45
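The callees outline the learning loop: masks are initialized (`_init_masks`), and each iteration computes a loss, adds a regularization term (`_loss_regularize`), and takes an optimizer step (`zero_grad`, `backward`, `step`). The docstring's promise that mask values lie in (0, 1) is consistent with squashing unconstrained parameters through a sigmoid, as in the original GNNExplainer formulation; this sketch assumes that. Below is a minimal pure-Python sketch of a GNNExplainer-style objective (a sparsity term plus an element-wise binary-entropy term on each sigmoid mask) and of reading off feature importance. The function names, coefficient names, default values, and the choice to sum the edge mask but average the feature mask are all hypothetical, not DGL's API.

```python
import math

EPS = 1e-15  # keeps log() away from log(0)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mask_penalty(logits, size_coef, ent_coef, reduce="sum"):
    """Sparsity + entropy penalty on sigmoid(logits) (hypothetical sketch)."""
    m = [sigmoid(x) for x in logits]
    # Sparsity: larger masks cost more, so the explanation stays small.
    size = sum(m) if reduce == "sum" else sum(m) / len(m)
    # Binary entropy peaks at 0.5 and vanishes near 0 or 1, so this term
    # pushes each mask entry toward a crisp keep/drop decision.
    ent = [-p * math.log(p + EPS) - (1 - p) * math.log(1 - p + EPS) for p in m]
    return size_coef * size + ent_coef * sum(ent) / len(ent)


def regularized_loss(pred_loss, edge_logits, feat_logits,
                     edge_size=0.005, edge_ent=1.0, feat_size=1.0, feat_ent=0.1):
    """Prediction loss plus penalties on both masks; coefficients are illustrative."""
    return (pred_loss
            + mask_penalty(edge_logits, edge_size, edge_ent, reduce="sum")
            + mask_penalty(feat_logits, feat_size, feat_ent, reduce="mean"))


def rank_features(feat_logits):
    """Feature indices from most to least important under the sigmoid mask."""
    mask = [sigmoid(x) for x in feat_logits]
    return sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)
```

In the real method these terms would be computed on tensors so that `backward()` can push gradients into the mask parameters; this sketch only evaluates the objective. For instance, `rank_features([0.1, 2.0, -1.0])` returns `[1, 0, 2]`, matching the docstring's "the higher, the more important".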

Tested by (1)

test_gnnexplainer (function) · 0.76