MCPcopy
hub / github.com/THUDM/CogDL / backward

Method backward

cogdl/operators/spmm.py:105–133  ·  view source on GitHub ↗
(ctx, grad_out)

Source from the content-addressed store, hash-verified

103
@staticmethod
def backward(ctx, grad_out):
    """Back-propagate through the sparse (CSR) x dense product.

    ``ctx.backward_csc`` is read in one of two layouts:

    * 5 items ``(rowptr, colind, quantized, edge_weight_csr, sym)``: the
      dense features were saved in quantized form and are restored here
      with ``dequantize_activation`` using the shape in ``ctx.other_args``.
    * 4 items ``(rowptr, colind, edge_weight_csr, sym)``: no dense
      features were saved, so no edge-weight gradient can be formed.

    Args:
        ctx: Autograd context providing ``backward_csc`` (and
            ``other_args`` for the 5-item layout). ``backward_csc`` is
            deleted from the context once unpacked.
        grad_out: Gradient with respect to the forward output.

    Returns:
        ``(None, None, grad_feat, grad_edge_weight, None)``.
        ``grad_edge_weight`` is ``None`` unless edge weights exist and
        have ``requires_grad`` set.
        NOTE(review): slot order presumably mirrors forward's
        ``(rowptr, colind, feat, edge_weight_csr, sym)`` -- confirm.

    Raises:
        RuntimeError: if the edge weights require a gradient but the
            context was saved without the (quantized) dense features.
    """
    # ``feat`` exists only in the 5-item layout; bind it up front so the
    # 4-item path can never hit an unbound local further down.
    feat = None
    if len(ctx.backward_csc) == 5:
        rowptr, colind, quantized, edge_weight_csr, sym = ctx.backward_csc
        feat = dequantize_activation(quantized, ctx.other_args)
        del quantized
    else:
        rowptr, colind, edge_weight_csr, sym = ctx.backward_csc
    # Drop the context's reference to the saved tensors as early as possible.
    del ctx.backward_csc

    # Both kernels below consume grad_out, so normalise its layout once for
    # both paths; this is a no-op when it is already contiguous.
    grad_out = grad_out.contiguous()

    # The feature gradient needs the transposed adjacency. The transpose can
    # be skipped only when the matrix is flagged symmetric; every falsy flag
    # (False, None, 0) takes the always-correct transposing path.
    if sym:
        colptr, rowind, edge_weight_csc = rowptr, colind, edge_weight_csr
    else:
        colptr, rowind, edge_weight_csc = spmm.csr2csc(rowptr, colind, edge_weight_csr)

    if edge_weight_csr is None:
        # Unweighted graph: there is no edge-weight gradient to produce.
        grad_feat = spmm.csr_spmm_no_edge_value(colptr, rowind, grad_out)
        return None, None, grad_feat, None, None

    grad_feat = spmm.csr_spmm(colptr, rowind, edge_weight_csc, grad_out)

    grad_edge_weight = None
    if edge_weight_csr.requires_grad:
        if feat is None:
            raise RuntimeError(
                "edge weights require grad but no features were saved for backward"
            )
        # The edge-weight gradient uses the original CSR pattern, not the
        # transposed one.
        grad_edge_weight = sddmm.csr_sddmm(rowptr, colind, grad_out, feat)

    return None, None, grad_feat, grad_edge_weight, None

Callers

nothing calls this directly

Calls 1

contiguousMethod · 0.80

Tested by

no test coverage detected