Node ID sampler for the random node sampler: draws node IDs with probability proportional to out-degree.
(self, g)
| 96 | self.output_device = output_device |
| 97 | |
| 98 | def node_sampler(self, g): |
| 99 | """Draw self.budget node IDs with replacement, weighted by out-degree (floored at 1), and return the unique IDs cast to g.idtype.""" |
| 100 | # Alternatively, this could be done by uniformly sampling a subset of edges |
| 101 | # and taking the source node of each sampled edge. Sampling over nodes is used |
| 102 | # instead because graphs typically have many more edges than nodes. |
| 103 | if self.cache and self.prob is not None:  # reuse the cached weights when caching is enabled |
| 104 | prob = self.prob |
| 105 | else: |
| 106 | prob = g.out_degrees().float().clamp(min=1)  # floor at 1 so nodes with zero out-degree can still be drawn |
| 107 | if self.cache:  # store the weights so later calls can skip recomputing them |
| 108 | self.prob = prob |
| 109 | return ( |
| 110 | torch.multinomial(prob, num_samples=self.budget, replacement=True)  # self.budget weighted draws, with replacement |
| 111 | .unique()  # drop repeated draws; the result may hold fewer than self.budget IDs |
| 112 | .type(g.idtype)  # match the graph's node-ID dtype |
| 113 | ) |
| 114 | |
| 115 | def edge_sampler(self, g): |
| 116 | """Node ID sampler for random edge sampler""" |
nothing calls this directly
no test coverage detected