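`get_eps(self, num_nodes, mat)`: Get the threshold for graph sparsification.

If `self.eps` is set, it is returned directly. Otherwise the threshold is inferred from `self.avg_degree`: it is the `(self.avg_degree * num_nodes)`-th largest entry of `mat`, or `-inf` when `self.avg_degree` exceeds `num_nodes`.

Source listing (the first column gives the line number in the original file):

| Line | Code |
| ---- | ---- |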
| 1457 | self.avg_degree = avg_degree |
| 1458 | |
| 1459 | def get_eps(self, num_nodes, mat): |
| 1460 | r"""Get the threshold for graph sparsification.""" |
| 1461 | if self.eps is None: |
| 1462 | # Infer from self.avg_degree |
| 1463 | if self.avg_degree > num_nodes: |
| 1464 | return float("-inf") |
| 1465 | sorted_weights = torch.sort(mat.flatten(), descending=True).values |
| 1466 | return sorted_weights[self.avg_degree * num_nodes - 1] |
| 1467 | else: |
| 1468 | return self.eps |
| 1469 | |
| 1470 | def __call__(self, g): |
| 1471 | # Step1: diffusion |