Repository: github.com/THUDM/CogDL

Method backward

cogdl/operators/fused_gat.py:25–41
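Signature: backward(ctx, grad_out), where grad_out is the gradient of the loss with respect to the output of forward().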


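```python
@staticmethod
def backward(ctx, grad_out):
    # Tensors stashed by forward(): the graph in both CSR (row_ptr, col_ind) and
    # CSC (col_ptr, row_ind) layouts, the softmax statistics edge_max / edge_sum,
    # and the inputs needed to recompute the attention weights.
    row_ptr, col_ind, col_ptr, row_ind, edge_max, edge_sum, in_feat, attn_row, attn_col = ctx.saved_tensors
    # The fused kernel presumably assumes a dense buffer, so force a contiguous layout.
    grad_out = grad_out.contiguous()
    grad_feat, grad_attn_row, grad_attn_col = fused_gatconv.gat_backward(
        ctx.negative_slope,
        row_ptr,
        col_ind,
        col_ptr,
        row_ind,
        edge_max,
        edge_sum,
        in_feat,
        attn_row,
        attn_col,
        grad_out,
    )
    # Autograd expects one gradient per forward() input, in the same order;
    # None marks inputs that need no gradient (e.g. graph index tensors, scalars).
    return grad_attn_row, grad_attn_col, None, None, None, None, None, grad_feat, None
```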
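The kernel call hides the math, so here is a sketch of the gradients it is expected to produce. This is an assumption, not confirmed by the source: it takes the forward pass to be single-head GAT edge-softmax aggregation over the CSR graph, with a_i = attn_row[i], b_j = attn_col[j], h_j = in_feat[j], λ = negative_slope, g_i = grad_out[i], and N(i) the column indices stored for row i. For multi-head tensors the same formulas would apply to each head separately.

$$
\begin{aligned}
s_{ij} &= a_i + b_j, \qquad e_{ij} = \mathrm{LeakyReLU}_\lambda(s_{ij}), \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})}, \qquad
o_i = \sum_{j \in N(i)} \alpha_{ij}\, h_j \\
\frac{\partial L}{\partial h_j} &= \sum_{i \,:\, j \in N(i)} \alpha_{ij}\, g_i \\
\frac{\partial L}{\partial e_{ij}} &= \alpha_{ij}\left(g_i^\top h_j - \sum_{k \in N(i)} \alpha_{ik}\, g_i^\top h_k\right)
  = \alpha_{ij}\left(g_i^\top h_j - g_i^\top o_i\right) \\
\frac{\partial L}{\partial s_{ij}} &= \frac{\partial L}{\partial e_{ij}} \cdot
  \begin{cases} 1 & s_{ij} > 0 \\ \lambda & s_{ij} \le 0 \end{cases} \\
\frac{\partial L}{\partial a_i} &= \sum_{j \in N(i)} \frac{\partial L}{\partial s_{ij}}, \qquad
\frac{\partial L}{\partial b_j} = \sum_{i \,:\, j \in N(i)} \frac{\partial L}{\partial s_{ij}}
\end{aligned}
$$

The sums over all i with j ∈ N(i) walk the incoming edges of node j, i.e. down column j, which the CSC arrays (col_ptr, row_ind) index directly; that is a plausible reason forward saves both layouts. Likewise, edge_max and edge_sum would let the kernel rebuild α_ij without repeating the row reductions, though this is inferred from their names.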

Callers

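Nothing calls this directly: as the backward of a custom torch.autograd.Function, it is invoked by PyTorch autograd when gradients flow back through the fused GAT operation.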

Calls 1

contiguous · Method · 0.80

Tested by

no test coverage detected
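Since no tests are detected, one hypothetical way to pin down the expected gradients is a stdlib-only reference on a tiny CSR graph, checked against central finite differences. It assumes the forward computes a single-head softmax over LeakyReLU(attn_row[i] + attn_col[j]) for each row i and aggregates in_feat rows with those weights. The names ref_gat_forward, ref_gat_backward, and max_gradient_error are illustrative only; they are not CogDL's API and do not call fused_gatconv.

```python
import math


def leaky_relu(x, slope):
    return x if x > 0 else slope * x


def ref_gat_forward(row_ptr, col_ind, attn_row, attn_col, feat, slope):
    """Single-head edge softmax + aggregation over a CSR graph (assumed semantics)."""
    n, dim = len(row_ptr) - 1, len(feat[0])
    out = [[0.0] * dim for _ in range(n)]
    alpha = [0.0] * len(col_ind)  # one attention weight per edge, in CSR order
    for i in range(n):
        edges = range(row_ptr[i], row_ptr[i + 1])
        if not edges:
            continue  # row with no neighbours: output stays zero
        scores = [leaky_relu(attn_row[i] + attn_col[col_ind[e]], slope) for e in edges]
        m = max(scores)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        for e, x in zip(edges, exps):
            alpha[e] = x / z
            for d in range(dim):
                out[i][d] += alpha[e] * feat[col_ind[e]][d]
    return out, alpha


def ref_gat_backward(row_ptr, col_ind, attn_row, attn_col, feat, slope, grad_out):
    """Return (grad_feat, grad_attn_row, grad_attn_col) given grad_out = dL/dout."""
    n, dim = len(row_ptr) - 1, len(feat[0])
    out, alpha = ref_gat_forward(row_ptr, col_ind, attn_row, attn_col, feat, slope)
    grad_feat = [[0.0] * dim for _ in feat]
    grad_row = [0.0] * len(attn_row)
    grad_col = [0.0] * len(attn_col)
    for i in range(n):
        g_dot_out = sum(g * o for g, o in zip(grad_out[i], out[i]))
        for e in range(row_ptr[i], row_ptr[i + 1]):
            j = col_ind[e]
            # Scatter into column j; a parallel kernel could gather via CSC instead.
            for d in range(dim):
                grad_feat[j][d] += alpha[e] * grad_out[i][d]
            # Softmax backward: alpha_ij * (g_i . h_j - g_i . out_i)
            g_dot_h = sum(g * h for g, h in zip(grad_out[i], feat[j]))
            grad_e = alpha[e] * (g_dot_h - g_dot_out)
            # LeakyReLU backward on s_ij = attn_row[i] + attn_col[j]
            grad_s = grad_e if attn_row[i] + attn_col[j] > 0 else slope * grad_e
            grad_row[i] += grad_s
            grad_col[j] += grad_s
    return grad_feat, grad_row, grad_col


def max_gradient_error(row_ptr, col_ind, attn_row, attn_col, feat, slope, grad_out, eps=1e-6):
    """Largest |analytic - central difference| over every scalar input."""
    grad_feat, grad_row, grad_col = ref_gat_backward(
        row_ptr, col_ind, attn_row, attn_col, feat, slope, grad_out)

    def loss():
        # L = sum_i <grad_out[i], out[i]>, so dL/dout is exactly grad_out
        out, _ = ref_gat_forward(row_ptr, col_ind, attn_row, attn_col, feat, slope)
        return sum(g * o for g_row, o_row in zip(grad_out, out) for g, o in zip(g_row, o_row))

    cases = [(attn_row, i, grad_row[i]) for i in range(len(attn_row))]
    cases += [(attn_col, j, grad_col[j]) for j in range(len(attn_col))]
    cases += [(feat[j], d, grad_feat[j][d]) for j in range(len(feat)) for d in range(len(feat[j]))]
    worst = 0.0
    for vec, k, analytic in cases:
        saved = vec[k]
        vec[k] = saved + eps
        up = loss()
        vec[k] = saved - eps
        down = loss()
        vec[k] = saved  # restore the perturbed input
        worst = max(worst, abs((up - down) / (2 * eps) - analytic))
    return worst


if __name__ == "__main__":
    # Toy graph: 3 nodes, 5 edges in CSR form (row i lists its neighbours j).
    row_ptr, col_ind = [0, 2, 3, 5], [0, 2, 1, 0, 1]
    attn_row, attn_col = [0.3, -0.5, 0.1], [0.2, 0.4, -0.7]
    feat = [[1.0, -2.0], [0.5, 0.3], [-1.0, 2.0]]
    grad_out = [[0.7, -0.2], [0.1, 0.4], [-0.3, 0.9]]
    print(max_gradient_error(row_ptr, col_ind, attn_row, attn_col, feat, 0.2, grad_out))
```

The toy inputs keep every attn_row[i] + attn_col[j] away from zero, so the LeakyReLU kink does not distort the finite differences and the printed error should be tiny. Comparing fused_gatconv.gat_backward against such a reference on small inputs would be one way to add coverage, provided the assumed forward semantics match the kernel.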