github.com/dmlc/dgl

Method forward

python/dgl/nn/pytorch/utils.py:492–556


(self, g, labels, mask=None)
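The `labels` argument accepts the three layouts listed in the docstring. The first branch of the method turns all of them into an (N, D) float matrix. Below is a minimal pure-Python sketch of that branch, in which nested lists stand in for tensors so it runs without PyTorch or DGL. `to_label_matrix` is a hypothetical helper name. The sketch assumes `F` is `torch.nn.functional`, whose `one_hot` infers the class count as the largest label plus one when `num_classes` is not passed.

```python
def to_label_matrix(labels):
    """Hypothetical mirror of forward's label preprocessing (lists instead of tensors)."""
    # Two or more columns: (N, C) one-hot or (N, L) multi-label, so just cast to float.
    if labels and isinstance(labels[0], list) and len(labels[0]) > 1:
        return [[float(v) for v in row] for row in labels]
    # (N,) or (N, 1): flatten to N class indices, then one-hot encode.
    flat = [row[0] if isinstance(row, list) else row for row in labels]
    # As with one_hot without num_classes, D is inferred as max label + 1.
    num_classes = max(flat) + 1
    return [[1.0 if c == y else 0.0 for c in range(num_classes)] for y in flat]


print(to_label_matrix([0, 2, 1]))  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(to_label_matrix([0, 0]))     # [[1.0], [1.0]]
```

A consequence of inferring the class count is that D can come out smaller than the true number of classes when the highest class index never appears in `labels`, as the second call shows.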


        self.reset = reset

    def forward(self, g, labels, mask=None):
        r"""Compute the label propagation process.

        Parameters
        ----------
        g : DGLGraph
            The input graph.
        labels : torch.Tensor
            The input node labels. There are three cases supported.

            * A LongTensor of shape :math:`(N, 1)` or :math:`(N,)` for node class labels in
              multiclass classification, where :math:`N` is the number of nodes.
            * A LongTensor of shape :math:`(N, C)` for one-hot encoding of node class labels
              in multiclass classification, where :math:`C` is the number of classes.
            * A LongTensor of shape :math:`(N, L)` for node labels in multilabel binary
              classification, where :math:`L` is the number of labels.
        mask : torch.Tensor
            The bool indicators of shape :math:`(N,)` with True denoting labeled nodes.
            Default: None, indicating all nodes are labeled.

        Returns
        -------
        torch.Tensor
            The propagated node labels of shape :math:`(N, D)` with float type, where :math:`D`
            is the number of classes or labels.
        """
        with g.local_scope():
            # multi-label / multi-class
            if len(labels.size()) > 1 and labels.size(1) > 1:
                labels = labels.to(th.float32)
            # single-label multi-class
            else:
                labels = F.one_hot(labels.view(-1)).to(th.float32)

            y = labels
            if mask is not None:
                y = th.zeros_like(labels)
                y[mask] = labels[mask]

            init = (1 - self.alpha) * y
            in_degs = g.in_degrees().float().clamp(min=1)
            out_degs = g.out_degrees().float().clamp(min=1)
            if self.norm_type == "sym":
                norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)
                norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)
            elif self.norm_type == "row":
                norm_i = th.pow(in_degs, -1.0).to(labels.device).unsqueeze(1)
            else:
                raise ValueError(
                    f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}"
                )

            for _ in range(self.k):
                g.ndata["h"] = y * norm_j if self.norm_type == "sym" else y
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                y = init + self.alpha * g.ndata["h"] * norm_i

                if self.clamp:
                    ...  # listing truncated here; the rest of the method (through line 556) is not shown
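The masking, degree-normalization and message-passing steps in the listing can be mirrored without DGL or PyTorch. The following is a dense, pure-Python sketch under these assumptions: the graph is a plain list of directed `(u, v)` edges, labels arrive already converted to an (N, D) float matrix, and `propagate` is a hypothetical name. `copy_u` followed by `sum` is reproduced as an explicit loop over edges. The sketch stops after the three lines of the loop body shown above, because whatever follows `if self.clamp:` is not visible in this listing.

```python
def propagate(edges, y0, mask=None, alpha=0.5, k=1, norm_type="sym"):
    """Hypothetical dense mirror of the loop shown in forward.

    edges: list of (u, v) pairs, one per directed edge u -> v.
    y0:    N x D float label matrix (already one-hot or cast to float).
    mask:  optional list of N bools, True for labeled nodes.
    """
    n, dim = len(y0), len(y0[0])
    # Unlabeled rows start at zero, mirroring y[mask] = labels[mask].
    y = [row[:] if mask is None or mask[i] else [0.0] * dim
         for i, row in enumerate(y0)]
    init = [[(1 - alpha) * val for val in row] for row in y]

    # In/out degrees, clamped to at least 1 as in clamp(min=1).
    in_deg, out_deg = [0] * n, [0] * n
    for u, v in edges:
        out_deg[u] += 1
        in_deg[v] += 1
    in_deg = [max(d, 1) for d in in_deg]
    out_deg = [max(d, 1) for d in out_deg]

    if norm_type == "sym":
        norm_i = [d ** -0.5 for d in in_deg]
        norm_j = [d ** -0.5 for d in out_deg]
    elif norm_type == "row":
        norm_i = [1.0 / d for d in in_deg]
        norm_j = [1.0] * n  # "row" mode leaves source-side messages unscaled
    else:
        raise ValueError(f"Expect norm_type to be 'sym' or 'row', got {norm_type}")

    for _ in range(k):
        # copy_u + sum: h[v] accumulates y[u] * norm_j[u] over every edge u -> v.
        h = [[0.0] * dim for _ in range(n)]
        for u, v in edges:
            for c in range(dim):
                h[v][c] += y[u][c] * norm_j[u]
        # y = init + alpha * h * norm_i, row by row.
        y = [[init[v][c] + alpha * h[v][c] * norm_i[v] for c in range(dim)]
             for v in range(n)]
    return y


# Path 0 - 1 - 2 stored in both directions; node 1 is unlabeled.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
y0 = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
out = propagate(edges, y0, mask=[True, False, True], alpha=0.5, k=1, norm_type="row")
print(out)  # [[0.5, 0.0], [0.25, 0.25], [0.0, 0.5]]
```

After one step, the unlabeled middle node receives equal mass from both labeled neighbours, and each labeled node keeps `(1 - alpha)` of its own label.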

Callers

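Nothing calls this directly. As the forward method of a PyTorch module, it is normally invoked by calling the module instance itself, e.g. module(g, labels, mask).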

Calls (8)

local_scope (Method) · 0.80
update_all (Method) · 0.80
size (Method) · 0.45
to (Method) · 0.45
float (Method) · 0.45
in_degrees (Method) · 0.45
out_degrees (Method) · 0.45
normalize (Method) · 0.45
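Read together, the in_degrees, out_degrees and update_all calls in the listing amount to one dense matrix update per iteration. The following is a sketch of that update under these assumptions: A is the adjacency matrix with A_{uv} counting the edges u → v, the diagonal degree matrices have their entries clamped to at least 1, and Y^{(0)} is y just before the loop (labels with unlabeled rows zeroed). It covers only the three loop-body lines shown, not whatever follows `if self.clamp:`.

```latex
% norm_type = "sym": messages scaled by out-degree^{-1/2}, sums scaled by in-degree^{-1/2}
Y^{(t+1)} = (1-\alpha)\,Y^{(0)}
          + \alpha\,\tilde{D}_{\mathrm{in}}^{-1/2}\,A^{\top}\,\tilde{D}_{\mathrm{out}}^{-1/2}\,Y^{(t)}

% norm_type = "row": sums scaled by in-degree^{-1}
Y^{(t+1)} = (1-\alpha)\,Y^{(0)}
          + \alpha\,\tilde{D}_{\mathrm{in}}^{-1}\,A^{\top}\,Y^{(t)}
```

For an undirected graph stored with both edge directions, A is symmetric and the in- and out-degree matrices coincide, so the "sym" case reduces to the familiar D^{-1/2} A D^{-1/2} normalization.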

Tested by

No test coverage detected.