(
ctx,
gidx: HeteroGraphIndex,
scores: Tensor,
eids: Tensor,
end_n_ids: Tensor,
norm_by: str,
)
| 129 | |
@staticmethod
def forward(
    ctx,
    gidx: HeteroGraphIndex,
    scores: Tensor,
    eids: Tensor,
    end_n_ids: Tensor,
    norm_by: str,
):
    """Apply sparsemax to edge scores, normalized per node.

    Parameters
    ----------
    ctx :
        Autograd context. On return it holds the graph index actually used
        (``ctx.backward_cache``) and the saved tensors ``(supp_size, out)``
        for the backward pass.
    gidx : HeteroGraphIndex
        Graph index whose edges carry ``scores``.
    scores : Tensor
        Edge scores. Cast to float32 before any computation, so the result
        is float32 regardless of the input dtype.
    eids : Tensor
        Edges to operate on. Unless ``is_all(eids)`` holds, the graph is
        replaced by ``gidx.edge_subgraph([eids], True).graph`` first.
    end_n_ids : Tensor
        Forwarded unchanged to ``_threshold_and_support_graph``; see that
        helper for its meaning.
    norm_by : str
        ``"src"`` reverses the graph before normalizing; every other value
        leaves the orientation unchanged (presumably normalizing over
        destination nodes -- confirm against the public wrapper).

    Returns
    -------
    Tensor
        Per-edge ``max(s - tau, 0)``, where ``s`` is the max-shifted score
        and ``tau`` is the threshold computed for the edge's node.
    """
    if not is_all(eids):
        gidx = gidx.edge_subgraph([eids], True).graph
    if norm_by == "src":
        gidx = gidx.reverse()

    # Use feat - max(feat) for numerical stability. The copy_rhs/max SpMM
    # presumably reduces each node's incident edge scores to their maximum;
    # the SDDMM then subtracts that node ("v") value from each edge ("e").
    scores = scores.float()
    scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
    scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")

    # Find the threshold tau for each node and apply ReLU(s - tau).
    tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
    out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)

    ctx.backward_cache = gidx
    ctx.save_for_backward(supp_size, out)
    # Deliberately no torch.cuda.empty_cache() here: it does not give
    # PyTorch more usable memory, and calling it on every forward pass
    # forces later allocations to go back to the device allocator.
    return out
| 156 | |
| 157 | @staticmethod |
| 158 | def backward(ctx, grad_out): |
nothing calls this directly
no test coverage detected