github.com/dmlc/dgl

Method explain_node

python/dgl/nn/pytorch/explain/pgexplainer.py:431–606

Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for a provided set of node IDs. Also return the prediction made with the graph and edge mask.
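The example model in the source below accepts an `edge_weight` argument and forwards it to `dgl.nn.GraphConv`, which scales each edge's message by its weight. Presumably this is how the edge mask enters the returned prediction, although the excerpt is cut off before that call. The sketch below illustrates only the scaling idea: `masked_sum_aggregate` is a hypothetical helper that works on plain Python lists and uses an unnormalized sum rather than GraphConv's normalized aggregation.

```python
def masked_sum_aggregate(num_nodes, edges, feats, edge_mask):
    """Hypothetical helper: sum incoming neighbor features, scaled by edge mask.

    edges: list of (src, dst) pairs; feats: one feature list per node;
    edge_mask: one weight per edge, in the same order as `edges`.
    """
    if len(edges) != len(edge_mask):
        raise ValueError("edge_mask must have one weight per edge")
    dim = len(feats[0]) if feats else 0
    out = [[0.0] * dim for _ in range(num_nodes)]
    for (src, dst), weight in zip(edges, edge_mask):
        for i in range(dim):
            # A weight near 0 effectively removes the edge from message passing.
            out[dst][i] += weight * feats[src][i]
    return out


# Node 2 receives messages from nodes 0 and 1; masking out edge (1, 2)
# removes node 1's contribution entirely.
print(masked_sum_aggregate(3, [(0, 2), (1, 2)], [[1.0], [10.0], [0.0]], [1.0, 0.0]))
```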

explain_node(self, nodes, graph, feat, temperature=1.0, training=False, **kwargs)
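The `temperature` and `training` arguments control how the edge mask is sampled. `concrete_sample` appears among the callees listed further down, but its body is not part of this excerpt, so the sketch below implements the standard binary concrete (Gumbel-sigmoid) relaxation as an assumption rather than a copy of DGL's code. `concrete_sample_sketch`, its `eps` guard, and the injectable `rng` are hypothetical, and plain Python lists stand in for tensors.

```python
import math
import random


def _sigmoid(x):
    # Numerically stable logistic function: never calls exp on a large positive value.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def concrete_sample_sketch(logits, temperature=1.0, training=False, eps=1e-4, rng=random):
    """Hypothetical binary concrete relaxation of per-edge Bernoulli gates.

    Training: add logistic noise to each logit, divide by the temperature,
    and apply a sigmoid, giving soft gates in (0, 1).
    Evaluation: return the deterministic sigmoid of each logit.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if not training:
        return [_sigmoid(w) for w in logits]
    gates = []
    for w in logits:
        # Shrink u into [eps, 1 - eps) so both logarithms stay finite.
        u = eps + (1.0 - 2.0 * eps) * rng.random()
        noise = math.log(u) - math.log(1.0 - u)  # sample from a standard logistic
        gates.append(_sigmoid((noise + w) / temperature))
    return gates


print(concrete_sample_sketch([0.0, 2.0]))  # evaluation mode: plain sigmoids
print(concrete_sample_sketch([0.0, 2.0], temperature=0.5, training=True, rng=random.Random(0)))
```

In this sketch, lowering `temperature` pushes training-time gates toward 0 or 1, a large temperature pulls them toward 0.5, and with `training=False` the temperature is ignored.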


    def explain_node(
        self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
    ):
        r"""Learn and return an edge mask that plays a crucial role in
        explaining the prediction made by the GNN for a provided set of node
        IDs. Also return the prediction made with the graph and edge mask.

        Parameters
        ----------
        nodes : int, iterable[int], tensor
            The nodes from the graph, which must not contain duplicate values.
        graph : DGLGraph
            A homogeneous graph.
        feat : Tensor
            The input feature of shape :math:`(N, D)`. :math:`N` is the
            number of nodes, and :math:`D` is the feature size.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        training : bool
            Whether to train the explanation network.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        Tensor
            Classification probabilities given the masked graph. It is a tensor
            of shape :math:`(N, L)`, where :math:`L` is the number of distinct
            node labels in the dataset, and :math:`N` is the number of nodes
            in the graph.
        Tensor
            Edge weights, a tensor of shape :math:`(E)`, where :math:`E`
            is the number of edges in the graph. A higher weight suggests a
            larger contribution of the edge.
        DGLGraph
            The batched set of subgraphs induced on the k-hop in-neighborhood
            of the input center nodes.
        Tensor
            The new IDs of the subgraph center nodes.

        Examples
        --------

        >>> import dgl
        >>> import numpy as np
        >>> import torch

        >>> # Define the model
        >>> class Model(torch.nn.Module):
        ...     def __init__(self, in_feats, out_feats):
        ...         super().__init__()
        ...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
        ...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
        ...
        ...     def forward(self, g, h, embed=False, edge_weight=None):
        ...         h = self.conv1(g, h, edge_weight=edge_weight)
        ...         if embed:
        ...             return h
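The Returns section states that a higher edge weight suggests a larger contribution of the edge, so a common way to read an explanation is to list the highest-weighted edges. The sketch below does this with plain Python lists; `top_k_edges` is a hypothetical helper, and it assumes the mask, source, and destination lists are aligned edge by edge (with PyTorch tensors, `torch.topk` would serve the same purpose).

```python
def top_k_edges(edge_mask, src, dst, k):
    """Hypothetical helper: return the k highest-weighted edges, largest first.

    edge_mask, src and dst are parallel lists: edge i runs src[i] -> dst[i].
    Each result is a (src, dst, weight) tuple.
    """
    if not (len(edge_mask) == len(src) == len(dst)):
        raise ValueError("edge_mask, src and dst must have the same length")
    # Sort edge indices by mask weight, highest first (ties keep input order).
    order = sorted(range(len(edge_mask)), key=lambda i: edge_mask[i], reverse=True)
    return [(src[i], dst[i], edge_mask[i]) for i in order[:k]]


print(top_k_edges([0.1, 0.9, 0.5], [0, 1, 2], [3, 3, 3], k=2))
```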

Callers (2)

train_step_node (method) · 0.95
test_pgexplainer (function) · 0.95

Calls (13)

concrete_sample (method) · 0.95
set_masks (method) · 0.95
clear_masks (method) · 0.95
batch (function) · 0.90
khop_in_subgraph (function) · 0.85
append (method) · 0.80
nonzero (method) · 0.80
to (method) · 0.45
long (method) · 0.45
nodes (method) · 0.45
edges (method) · 0.45
num_edges (method) · 0.45
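The Returns section describes a batched set of subgraphs induced on the k-hop in-neighborhood of each center node, plus the center nodes' new IDs, and `khop_in_subgraph` and `batch` appear among the callees above. The framework-free sketch below illustrates both steps under stated assumptions: `khop_in_subgraph_sketch` and `batch_subgraphs` are hypothetical helpers over `(src, dst)` edge lists, kept nodes are relabeled in sorted order for determinism (DGL's own ordering may differ), and batching simply offsets each subgraph's node IDs.

```python
def khop_in_subgraph_sketch(edges, center, k):
    """Hypothetical helper: induced subgraph on nodes that reach `center` in <= k hops.

    Returns (nodes, local_edges, local_center): `nodes` are original IDs in
    sorted order; `local_edges` and `local_center` use new IDs 0..len(nodes)-1.
    """
    keep = {center}
    frontier = {center}
    for _ in range(k):
        # Follow edges backwards: predecessors of the frontier not yet kept.
        frontier = {s for s, d in edges if d in frontier} - keep
        keep |= frontier
    nodes = sorted(keep)
    new_id = {old: new for new, old in enumerate(nodes)}
    # Induced subgraph: keep every edge whose endpoints were both kept.
    local_edges = [(new_id[s], new_id[d]) for s, d in edges if s in keep and d in keep]
    return nodes, local_edges, new_id[center]


def batch_subgraphs(subgraphs):
    """Hypothetical helper: merge subgraphs into one disjoint graph.

    Returns (num_nodes, edges, center_ids) for the batched graph.
    """
    offset, edges, centers = 0, [], []
    for nodes, local_edges, local_center in subgraphs:
        edges.extend((s + offset, d + offset) for s, d in local_edges)
        centers.append(local_center + offset)
        offset += len(nodes)
    return offset, edges, centers


edges = [(0, 1), (1, 2), (3, 2), (2, 4), (5, 0)]
sg_a = khop_in_subgraph_sketch(edges, center=2, k=1)
sg_b = khop_in_subgraph_sketch(edges, center=0, k=1)
print(batch_subgraphs([sg_a, sg_b]))
```

In this example, center node 2 of the first subgraph receives local ID 1, and center node 0 of the second subgraph becomes ID 3 once its node IDs are offset by the first subgraph's three nodes.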

Tested by (1)

test_pgexplainer (function) · 0.76