r"""Get the threshold for graph sparsification.
(self, num_nodes, mat)
| 1252 | self.avg_degree = avg_degree |
| 1253 | |
def get_eps(self, num_nodes, mat):
    r"""Get the threshold for graph sparsification.

    If ``self.eps`` is set, it is returned unchanged. Otherwise the threshold
    is inferred from ``self.avg_degree``. It is the
    ``(avg_degree * num_nodes)``-th largest entry of ``mat``, so at least
    ``avg_degree * num_nodes`` entries of ``mat`` are >= the returned value.

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the graph.
    mat : torch.Tensor
        Weight matrix to be sparsified. It is flattened before sorting.
        NOTE(review): the bound check below presumes ``mat`` holds
        ``num_nodes * num_nodes`` entries (dense adjacency) -- confirm
        against the caller.

    Returns
    -------
    float or torch.Tensor
        ``self.eps`` when it is not None. ``float("-inf")`` (below every
        finite weight) when ``avg_degree > num_nodes``. ``float("inf")``
        (above every finite weight) when ``avg_degree * num_nodes <= 0``.
        Otherwise a 0-d tensor holding the selected weight from ``mat``.
    """
    if self.eps is not None:
        return self.eps

    # Infer from self.avg_degree.
    if self.avg_degree > num_nodes:
        # More entries requested than the matrix can supply.
        return float("-inf")
    num_kept = self.avg_degree * num_nodes
    if num_kept <= 0:
        # Without this guard the index ``num_kept - 1`` is negative and wraps
        # around to the *smallest* weight (or raises on an empty tensor),
        # yielding a threshold that admits every entry instead of none.
        return float("inf")
    sorted_weights = torch.sort(mat.flatten(), descending=True).values
    return sorted_weights[num_kept - 1]
| 1264 | |
| 1265 | def __call__(self, g): |
| 1266 | # Step1: PPR diffusion |