r"""Learn and return a node feature mask and subgraph that play a crucial role to explain the prediction made by the GNN for node :attr:`node_id`. Parameters ---------- node_id : int The node to explain. graph : DGLGraph A homo
(self, node_id, graph, feat, **kwargs)
| 163 | return loss |
| 164 | |
| 165 | def explain_node(self, node_id, graph, feat, **kwargs): |
| 166 | r"""Learn and return a node feature mask and subgraph that play a |
| 167 | crucial role in explaining the prediction made by the GNN for node |
| 168 | :attr:`node_id`. |
| 169 | |
| 170 | Parameters |
| 171 | ---------- |
| 172 | node_id : int |
| 173 | The node to explain. |
| 174 | graph : DGLGraph |
| 175 | A homogeneous graph. |
| 176 | feat : Tensor |
| 177 | The input feature of shape :math:`(N, D)`. :math:`N` is the |
| 178 | number of nodes, and :math:`D` is the feature size. |
| 179 | kwargs : dict |
| 180 | Additional arguments passed to the GNN model. Tensors whose |
| 181 | first dimension is the number of nodes or edges will be |
| 182 | assumed to be node/edge features. |
| 183 | |
| 184 | Returns |
| 185 | ------- |
| 186 | new_node_id : Tensor |
| 187 | The new ID of the input center node. |
| 188 | sg : DGLGraph |
| 189 | The subgraph induced on the k-hop in-neighborhood of the input center node. |
| 190 | feat_mask : Tensor |
| 191 | Learned node feature importance mask of shape :math:`(D)`, where :math:`D` is the |
| 192 | feature size. The values are within range :math:`(0, 1)`. |
| 193 | The higher, the more important. |
| 194 | edge_mask : Tensor |
| 195 | Learned importance mask of the edges in the subgraph, which is a tensor |
| 196 | of shape :math:`(E)`, where :math:`E` is the number of edges in the |
| 197 | subgraph. The values are within range :math:`(0, 1)`. |
| 198 | The higher, the more important. |
| 199 | |
| 200 | Examples |
| 201 | -------- |
| 202 | |
| 203 | >>> import dgl |
| 204 | >>> import dgl.function as fn |
| 205 | >>> import torch |
| 206 | >>> import torch.nn as nn |
| 207 | >>> from dgl.data import CoraGraphDataset |
| 208 | >>> from dgl.nn import GNNExplainer |
| 209 | |
| 210 | >>> # Load dataset |
| 211 | >>> data = CoraGraphDataset() |
| 212 | >>> g = data[0] |
| 213 | >>> features = g.ndata['feat'] |
| 214 | >>> labels = g.ndata['label'] |
| 215 | >>> train_mask = g.ndata['train_mask'] |
| 216 | |
| 217 | >>> # Define a model |
| 218 | >>> class Model(nn.Module): |
| 219 | ... def __init__(self, in_feats, out_feats): |
| 220 | ... super(Model, self).__init__() |
| 221 | ... self.linear = nn.Linear(in_feats, out_feats) |
| 222 | ... |