Repository: github.com/dmlc/dgl

Method train_step_node

Source: python/dgl/nn/pytorch/explain/pgexplainer.py:248–293

Signature: train_step_node(self, nodes, graph, feat, temperature, **kwargs)
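
Per the docstring, `nodes` may be a single int, an iterable of ints, or a tensor, and must not contain duplicate values; the method body converts a tensor with `.tolist()` and wraps a bare int in a list. The sketch below reproduces that coercion in plain Python as a hypothetical helper, `normalize_nodes` (not part of DGL), and adds an explicit duplicate check purely for illustration. It relies on duck typing (`hasattr(nodes, "tolist")`) instead of `isinstance(nodes, torch.Tensor)` so that it runs without PyTorch installed.

```python
from typing import Iterable, List, Union


def normalize_nodes(nodes: Union[int, Iterable[int]]) -> List[int]:
    """Hypothetical helper: coerce `nodes` into a list of unique node IDs.

    Mirrors the tensor -> list and int -> [int] steps of train_step_node;
    the duplicate check is an illustrative addition.
    """
    # Tensor-like objects expose .tolist(); a 0-dim tensor yields a plain int.
    if hasattr(nodes, "tolist"):
        nodes = nodes.tolist()
    # A single node ID becomes a one-element list.
    if isinstance(nodes, int):
        nodes = [nodes]
    nodes = [int(n) for n in nodes]
    # Enforce the docstring's "no duplicate values" requirement explicitly.
    if len(set(nodes)) != len(nodes):
        raise ValueError("nodes must not contain duplicate IDs")
    return nodes
```

For example, `normalize_nodes(3)` returns `[3]`, `normalize_nodes(range(3))` returns `[0, 1, 2]`, and `normalize_nodes([1, 1])` raises `ValueError`.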

def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
    r"""Compute the loss of the explanation network for node classification

    Parameters
    ----------
    nodes : int, iterable[int], tensor
        The nodes from the graph used to train the explanation network,
        which cannot have any duplicate value.
    graph : DGLGraph
        Input homogeneous graph.
    feat : Tensor
        The input feature of shape :math:`(N, D)`. :math:`N` is the
        number of nodes, and :math:`D` is the feature size.
    temperature : float
        The temperature parameter fed to the sampling procedure.
    kwargs : dict
        Additional arguments passed to the GNN model.

    Returns
    -------
    Tensor
        A scalar tensor representing the loss.
    """
    assert (
        not self.graph_explanation
    ), '"explain_graph" must be False when initializing the module.'

    self.model = self.model.to(graph.device)
    self.elayers = self.elayers.to(graph.device)

    if isinstance(nodes, torch.Tensor):
        nodes = nodes.tolist()
    if isinstance(nodes, int):
        nodes = [nodes]

    prob, _, batched_graph, inverse_indices = self.explain_node(
        nodes, graph, feat, temperature, training=True, **kwargs
    )

    pred = self.model(
        batched_graph, self.batched_feats, embed=False, **kwargs
    )
    pred = pred.argmax(-1).data

    loss = self.loss(prob[inverse_indices], pred[inverse_indices])
    return loss
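
The loss trains the explainer to preserve the GNN's own decision: `pred` is the argmax of the model's output on `batched_graph` (the call passes no edge mask), so the targets are the model's predicted classes rather than ground-truth labels, and `inverse_indices` restricts both `prob` and `pred` to the requested nodes. `self.loss` is not shown in this excerpt. The sketch below assumes its prediction term is the mean negative log-probability that `prob` assigns to each node's originally predicted class, which is the fidelity term of the PGExplainer objective; any mask regularizers the real method may add (the PGExplainer paper uses size and entropy penalties) are omitted. `prediction_fidelity_loss` is a hypothetical name, and plain lists stand in for tensors.

```python
import math
from typing import Sequence


def prediction_fidelity_loss(
    prob: Sequence[Sequence[float]],
    logits: Sequence[Sequence[float]],
    inverse_indices: Sequence[int],
    eps: float = 1e-6,
) -> float:
    """Hypothetical pure-Python analogue of the prediction term.

    prob            -- per-node class probabilities from the explainer
                       (assumed meaning of `prob` returned by explain_node)
    logits          -- per-node scores from the unmasked model
                       (`pred` before argmax in train_step_node)
    inverse_indices -- positions of the target nodes within the batch
    """
    # Pseudo-labels: the model's own predicted class, like pred.argmax(-1).
    labels = [max(range(len(row)), key=row.__getitem__) for row in logits]
    # Keep only the target nodes, like prob[inverse_indices] / pred[inverse_indices].
    selected = [(prob[i], labels[i]) for i in inverse_indices]
    # Mean negative log-likelihood of the originally predicted class;
    # eps guards against log(0).
    return sum(-math.log(p[c] + eps) for p, c in selected) / len(selected)
```

A loss near zero means the explanation reproduces the model's confident original predictions; probability mass moved to other classes raises it.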

Callers (1)

test_pgexplainer · Function · 0.95

Calls (3)

explain_node · Method · 0.95
loss · Method · 0.95
to · Method · 0.45

Tested by (1)

test_pgexplainer · Function · 0.76
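
`temperature` is supplied by the caller on every call to `train_step_node`. A common pattern, assumed here and not taken from this file or from `test_pgexplainer`, is to anneal it across epochs: under a concrete (Gumbel-sigmoid) relaxation, a higher temperature yields softer mask samples early in training and a lower one yields sharper, closer-to-binary samples later. The hypothetical `annealed_temperature` below interpolates geometrically from `init_tmp` to `final_tmp` (the 5.0 and 1.0 defaults are illustrative); the commented loop shows where its value would be passed, with `explainer`, `optimizer`, `nodes`, `graph`, and `feat` as placeholders.

```python
def annealed_temperature(
    epoch: int, epochs: int, init_tmp: float = 5.0, final_tmp: float = 1.0
) -> float:
    """Hypothetical geometric schedule from init_tmp down to final_tmp."""
    # Equals init_tmp at epoch 0 and final_tmp at epoch == epochs.
    return init_tmp * (final_tmp / init_tmp) ** (epoch / epochs)


# Intended use with an explainer built with explain_graph=False (not run here):
#   for epoch in range(epochs):
#       tmp = annealed_temperature(epoch, epochs)
#       loss = explainer.train_step_node(nodes, graph, feat, tmp)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
```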