r"""Compute the loss of the explanation network for node classification Parameters ---------- nodes : int, iterable[int], tensor The nodes from the graph used to train the explanation network, which cannot have any duplicate value. graph : DGL
(self, nodes, graph, feat, temperature, **kwargs)
| 246 | return loss |
| 247 | |
def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
    r"""Compute the loss of the explanation network for node classification

    Parameters
    ----------
    nodes : int, iterable[int], tensor
        The nodes from the graph used to train the explanation network,
        which cannot have any duplicate value.
    graph : DGLGraph
        Input homogeneous graph.
    feat : Tensor
        The input feature of shape :math:`(N, D)`. :math:`N` is the
        number of nodes, and :math:`D` is the feature size.
    temperature : float
        The temperature parameter fed to the sampling procedure.
    kwargs : dict
        Additional arguments passed to the GNN model.

    Returns
    -------
    Tensor
        A scalar tensor representing the loss.

    Raises
    ------
    AssertionError
        If the module was initialized for graph-level explanation
        (``self.graph_explanation`` is truthy).
    """
    assert (
        not self.graph_explanation
    ), '"explain_graph" must be False when initializing the module.'

    # Move the GNN and the explanation layers to the graph's device.
    self.model = self.model.to(graph.device)
    self.elayers = self.elayers.to(graph.device)

    # Normalize ``nodes``: a tensor becomes a Python list (or a plain int
    # for a 0-d tensor), and a single int is wrapped in a one-element list.
    if isinstance(nodes, torch.Tensor):
        nodes = nodes.tolist()
    if isinstance(nodes, int):
        nodes = [nodes]

    prob, _, batched_graph, inverse_indices = self.explain_node(
        nodes, graph, feat, temperature, training=True, **kwargs
    )

    # NOTE(review): ``self.batched_feats`` is read here but not set in this
    # method; presumably ``explain_node`` populates it -- confirm.
    # The GNN's predictions are only used as argmax target labels. argmax is
    # not differentiable, so no gradient can flow back through this forward
    # pass; running it under no_grad avoids building a throwaway autograd
    # graph without changing the resulting labels.
    with torch.no_grad():
        pred = self.model(
            batched_graph, self.batched_feats, embed=False, **kwargs
        )
        pred = pred.argmax(-1)

    # ``inverse_indices`` selects the rows that correspond to ``nodes``
    # (as returned by ``explain_node``).
    loss = self.loss(prob[inverse_indices], pred[inverse_indices])
    return loss
| 294 | |
| 295 | def explain_graph( |
| 296 | self, graph, feat, temperature=1.0, training=False, **kwargs |