(ctx, grad_out)
| Line | Source |
|---:|---|
| 23 | |
| 24 | @staticmethod |
| 25 | def backward(ctx, grad_out): |
| 26 | row_ptr, col_ind, col_ptr, row_ind, edge_max, edge_sum, in_feat, attn_row, attn_col = ctx.saved_tensors |
| 27 | grad_out = grad_out.contiguous() |
| 28 | grad_feat, grad_attn_row, grad_attn_col = fused_gatconv.gat_backward( |
| 29 | ctx.negative_slope, |
| 30 | row_ptr, |
| 31 | col_ind, |
| 32 | col_ptr, |
| 33 | row_ind, |
| 34 | edge_max, |
| 35 | edge_sum, |
| 36 | in_feat, |
| 37 | attn_row, |
| 38 | attn_col, |
| 39 | grad_out, |
| 40 | ) |
| 41 | return grad_attn_row, grad_attn_col, None, None, None, None, None, grad_feat, None |
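No direct callers were found. This is expected for the `backward` of a custom autograd function, which PyTorch's autograd engine invokes during the backward pass rather than user code.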
No test coverage was detected.