MCPcopy
hub / github.com/dmlc/dgl / forward

Method forward

examples/pytorch/hgp_sl/functions.py:131–155  ·  view source on GitHub ↗
(
        ctx,
        gidx: HeteroGraphIndex,
        scores: Tensor,
        eids: Tensor,
        end_n_ids: Tensor,
        norm_by: str,
    )

Source from the content-addressed store, hash-verified

129
@staticmethod
def forward(
    ctx,
    gidx: HeteroGraphIndex,
    scores: Tensor,
    eids: Tensor,
    end_n_ids: Tensor,
    norm_by: str,
):
    """Threshold per-edge ``scores`` per node and keep only the positive part.

    The ``ReLU(u - t(u))`` step suggests this is the forward half of an
    edge-wise sparsemax. That is an inference; confirm it against the
    enclosing autograd Function.

    Args:
        ctx: autograd context. On return it holds ``backward_cache`` (the
            graph index actually used below) and the saved tensors
            ``(supp_size, out)``.
        gidx: graph index that the scores are defined on.
        scores: per-edge scores. They are cast with ``.float()`` (float32)
            before any arithmetic.
        eids: edge ids to operate on. When ``is_all(eids)`` holds, the
            whole graph is used. Otherwise an edge subgraph is built from
            these ids.
        end_n_ids: passed through unchanged to
            ``_threshold_and_support_graph``.
        norm_by: ``"src"`` reverses the graph before normalising. Any other
            value keeps the edge direction (presumably meaning ``"dst"``;
            it is not validated here).

    Returns:
        Per-edge tensor ``clamp(scores - node_max - tau, min=0)``.
    """
    # Restrict to the requested edges unless every edge was selected.
    graph = gidx if is_all(eids) else gidx.edge_subgraph([eids], True).graph
    # Normalising over source nodes is done by flipping edge direction, so
    # that the node-side reductions below group by the original sources.
    if norm_by == "src":
        graph = graph.reverse()

    # Shift by the per-node maximum so the later arithmetic stays stable.
    logits = scores.float()
    node_max = _gspmm(graph, "copy_rhs", "max", None, logits)[0]
    shifted = _gsddmm(graph, "sub", logits, node_max, "e", "v")

    # Per-node threshold tau, then ReLU(shifted - tau) on every edge.
    threshold, support = _threshold_and_support_graph(graph, shifted, end_n_ids)
    projected = torch.clamp(
        _gsddmm(graph, "sub", shifted, threshold, "e", "v"), min=0
    )

    # backward() reads the graph from ctx and the saved tensors in this order.
    ctx.backward_cache = graph
    ctx.save_for_backward(support, projected)
    # NOTE(review): emptying the CUDA cache on every forward call releases
    # cached allocator blocks and can slow training noticeably. Kept as in
    # the original; consider removing it.
    torch.cuda.empty_cache()
    return projected
156
157 @staticmethod
158 def backward(ctx, grad_out):

Callers

nothing calls this directly

Calls 7

is_all — Function · 0.90
_gspmm — Function · 0.90
_gsddmm — Function · 0.90
reverse — Method · 0.80
edge_subgraph — Method · 0.45
float — Method · 0.45

Tested by

no test coverage detected