r"""Compute the label propagation process. Parameters ---------- g : DGLGraph The input graph. labels : torch.Tensor The input node labels. There are three cases supported. * A LongTensor of shape :math:`(N, 1)` or :math:`(N,)` fo
(self, g, labels, mask=None)
| 490 | self.reset = reset |
| 491 | |
| 492 | def forward(self, g, labels, mask=None): |
| 493 | r"""Compute the label propagation process. |
| 494 | |
| 495 | Parameters |
| 496 | ---------- |
| 497 | g : DGLGraph |
| 498 | The input graph. |
| 499 | labels : torch.Tensor |
| 500 | The input node labels. There are three cases supported. |
| 501 | |
| 502 | * A LongTensor of shape :math:`(N, 1)` or :math:`(N,)` for node class labels in |
| 503 | multiclass classification, where :math:`N` is the number of nodes. |
| 504 | * A LongTensor of shape :math:`(N, C)` for one-hot encoding of node class labels |
| 505 | in multiclass classification, where :math:`C` is the number of classes. |
| 506 | * A LongTensor of shape :math:`(N, L)` for node labels in multilabel binary |
| 507 | classification, where :math:`L` is the number of labels. |
| 508 | mask : torch.Tensor |
| 509 | The bool indicators of shape :math:`(N,)` with True denoting labeled nodes. |
| 510 | Default: None, indicating all nodes are labeled. |
| 511 | |
| 512 | Returns |
| 513 | ------- |
| 514 | torch.Tensor |
| 515 | The propagated node labels of shape :math:`(N, D)` with float type, where :math:`D` |
| 516 | is the number of classes or labels. |
| 517 | """ |
| 518 | with g.local_scope(): |
| 519 | # multi-label / multi-class |
| 520 | if len(labels.size()) > 1 and labels.size(1) > 1: |
| 521 | labels = labels.to(th.float32) |
| 522 | # single-label multi-class |
| 523 | else: |
| 524 | labels = F.one_hot(labels.view(-1)).to(th.float32) |
| 525 | |
| 526 | y = labels |
| 527 | if mask is not None: |
| 528 | y = th.zeros_like(labels) |
| 529 | y[mask] = labels[mask] |
| 530 | |
| 531 | init = (1 - self.alpha) * y |
| 532 | in_degs = g.in_degrees().float().clamp(min=1) |
| 533 | out_degs = g.out_degrees().float().clamp(min=1) |
| 534 | if self.norm_type == "sym": |
| 535 | norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1) |
| 536 | norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1) |
| 537 | elif self.norm_type == "row": |
| 538 | norm_i = th.pow(in_degs, -1.0).to(labels.device).unsqueeze(1) |
| 539 | else: |
| 540 | raise ValueError( |
| 541 | f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}" |
| 542 | ) |
| 543 | |
| 544 | for _ in range(self.k): |
| 545 | g.ndata["h"] = y * norm_j if self.norm_type == "sym" else y |
| 546 | g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) |
| 547 | y = init + self.alpha * g.ndata["h"] * norm_i |
| 548 | |
| 549 | if self.clamp: |
nothing calls this directly
no test coverage detected