r"""Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for the provided set of node IDs. Also, return the prediction made with the graph and edge mask. Parameters ---------- nodes : int, iterable[int], tensor
(
self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
)
| 429 | return (probs, edge_mask) |
| 430 | |
| 431 | def explain_node( |
| 432 | self, nodes, graph, feat, temperature=1.0, training=False, **kwargs |
| 433 | ): |
| 434 | r"""Learn and return an edge mask that plays a crucial role in |
| 435 | explaining the prediction made by the GNN for the provided set of node IDs. |
| 436 | Also, return the prediction made with the graph and edge mask. |
| 437 | |
| 438 | Parameters |
| 439 | ---------- |
| 440 | nodes : int, iterable[int], tensor |
| 441 | The nodes from the graph, which cannot have any duplicate value. |
| 442 | graph : DGLGraph |
| 443 | A homogeneous graph. |
| 444 | feat : Tensor |
| 445 | The input feature of shape :math:`(N, D)`. :math:`N` is the |
| 446 | number of nodes, and :math:`D` is the feature size. |
| 447 | temperature : float |
| 448 | The temperature parameter fed to the sampling procedure. |
| 449 | training : bool |
| 450 | Whether the explanation network is being trained. |
| 451 | kwargs : dict |
| 452 | Additional arguments passed to the GNN model. |
| 453 | |
| 454 | Returns |
| 455 | ------- |
| 456 | Tensor |
| 457 | Classification probabilities given the masked graph. It is a tensor |
| 458 | of shape :math:`(N, L)`, where :math:`L` is the number of distinct |
| 459 | node label types in the dataset, and :math:`N` is the number of nodes |
| 460 | in the graph. |
| 461 | Tensor |
| 462 | Edge weights, a tensor of shape :math:`(E)`, where :math:`E` |
| 463 | is the number of edges in the graph. A higher weight suggests a |
| 464 | larger contribution of the edge. |
| 465 | DGLGraph |
| 466 | The batched set of subgraphs induced on the k-hop in-neighborhood |
| 467 | of the input center nodes. |
| 468 | Tensor |
| 469 | The new IDs of the subgraph center nodes. |
| 470 | |
| 471 | Examples |
| 472 | -------- |
| 473 | |
| 474 | >>> import dgl |
| 475 | >>> import numpy as np |
| 476 | >>> import torch |
| 477 | |
| 478 | >>> # Define the model |
| 479 | >>> class Model(torch.nn.Module): |
| 480 | ... def __init__(self, in_feats, out_feats): |
| 481 | ... super().__init__() |
| 482 | ... self.conv1 = dgl.nn.GraphConv(in_feats, out_feats) |
| 483 | ... self.conv2 = dgl.nn.GraphConv(out_feats, out_feats) |
| 484 | ... |
| 485 | ... def forward(self, g, h, embed=False, edge_weight=None): |
| 486 | ... h = self.conv1(g, h, edge_weight=edge_weight) |
| 487 | ... if embed: |
| 488 | ... return h |