Sampling function. Parameters ---------- g : DGLGraph The graph to sample from. indices : Tensor or dict[str, Tensor] Nodes which induce the subgraph. exclude_eids : Tensor or dict[etype, Tensor], optional The edges to exclude from the sampled subgraph.
(
self, g, indices, exclude_eids=None
)
| 63 | self.output_device = output_device |
| 64 | |
| 65 | def sample( |
| 66 | self, g, indices, exclude_eids=None |
| 67 | ): # pylint: disable=arguments-differ |
| 68 | """Sampling function. |
| 69 | |
| 70 | Parameters |
| 71 | ---------- |
| 72 | g : DGLGraph |
| 73 | The graph to sample from. |
| 74 | indices : Tensor or dict[str, Tensor] |
| 75 | Nodes which induce the subgraph. |
| 76 | exclude_eids : Tensor or dict[etype, Tensor], optional |
| 77 | The edges to exclude from the sampled subgraph. |
| 78 | |
| 79 | Returns |
| 80 | ------- |
| 81 | input_nodes : Tensor or dict[str, Tensor] |
| 82 | The node IDs inducing the subgraph. |
| 83 | output_nodes : Tensor or dict[str, Tensor] |
| 84 | The node IDs that are sampled in this minibatch. |
| 85 | subg : DGLGraph |
| 86 | The subgraph itself. |
| 87 | """ |
| 88 | |
| 89 | # Keep the seed nodes as the output nodes and start the list of reached nodes with them. |
| 90 | output_nodes = indices |
| 91 | all_reached_nodes = [indices] |
| 92 | |
| 93 | # Iterate over fanout. |
| 94 | for fanout in reversed(self.fanouts): |
| 95 | |
| 96 | # Sample frontier. |
| 97 | frontier = g.sample_neighbors( |
| 98 | indices, |
| 99 | fanout, |
| 100 | output_device=self.output_device, |
| 101 | replace=self.replace, |
| 102 | prob=self.prob, |
| 103 | exclude_edges=exclude_eids, |
| 104 | ) |
| 105 | |
| 106 | # Get reached nodes. |
| 107 | curr_reached = defaultdict(list) |
| 108 | for c_etype in frontier.canonical_etypes: |
| 109 | (src_type, _, _) = c_etype |
| 110 | src, _ = frontier.edges(etype=c_etype) |
| 111 | curr_reached[src_type].append(src) |
| 112 | |
| 113 | # De-duplication. |
| 114 | curr_reached = { |
| 115 | ntype: torch.unique(torch.cat(srcs)) |
| 116 | for ntype, srcs in curr_reached.items() |
| 117 | } |
| 118 | |
| 119 | # Generate type sampling probabilities. |
| 120 | type_count = { |
| 121 | node_type: indices.shape[0] |
| 122 | for node_type, indices in curr_reached.items() |
nothing calls this directly
no test coverage detected