r"""Get the threshold for graph sparsification.
(self, num_nodes, mat)
| 1364 | self.avg_degree = avg_degree |
| 1365 | |
def get_eps(self, num_nodes, mat):
    r"""Get the threshold for graph sparsification.

    Parameters
    ----------
    num_nodes : int
        Number of nodes. Multiplied by ``self.avg_degree`` to obtain how
        many of the largest entries of ``mat`` the threshold should admit.
    mat : torch.Tensor
        Weight matrix; it is flattened and ranked in descending order.
        (Presumably ``num_nodes x num_nodes`` -- confirm against caller.)

    Returns
    -------
    float or torch.Tensor
        ``self.eps`` unchanged when it is set. Otherwise:

        * ``float("-inf")`` if ``self.avg_degree`` exceeds ``num_nodes``;
        * ``float("inf")`` if ``self.avg_degree * num_nodes`` is not
          positive (no entry should reach the threshold);
        * else the ``(self.avg_degree * num_nodes)``-th largest entry of
          ``mat`` as a 0-dim tensor.
    """
    # An explicitly configured threshold always wins.
    if self.eps is not None:
        return self.eps

    # Infer from self.avg_degree
    if self.avg_degree > num_nodes:
        return float("-inf")

    num_kept = self.avg_degree * num_nodes
    if num_kept <= 0:
        # Without this guard the index below would be negative and wrap
        # around to the *smallest* weight (or raise on an empty matrix),
        # i.e. the most permissive threshold instead of the strictest.
        return float("inf")

    sorted_weights = torch.sort(mat.flatten(), descending=True).values
    return sorted_weights[num_kept - 1]
| 1376 | |
| 1377 | def __call__(self, g): |
| 1378 | # Step1: heat kernel diffusion |