(ctx, grad_out)
| 103 | |
@staticmethod
def backward(ctx, grad_out):
    """Compute gradients for the sparse (CSR) x dense product of the forward pass.

    Assumes (TODO confirm against ``forward``) that the forward pass computes
    ``out = A @ feat``, where ``A`` is stored in CSR form as
    ``(rowptr, colind, edge_weight_csr)``. An unweighted ``A`` is represented
    by ``edge_weight_csr is None``.

    ``ctx.backward_csc`` holds one of two tuples:

    * ``(rowptr, colind, quantized, edge_weight_csr, sym)``: the forward
      features were saved in quantized form, and ``ctx.other_args`` carries
      the shape passed to ``dequantize_activation``.
    * ``(rowptr, colind, edge_weight_csr, sym)``: no features were saved.

    When ``sym`` is truthy, the CSR arrays are used directly in place of the
    transposed (CSC) arrays; presumably this flags a symmetric ``A``.

    Args:
        ctx: autograd context filled by the forward pass. ``backward_csc`` is
            deleted from it here so the saved tensors can be released early.
        grad_out: gradient of the loss w.r.t. the forward output.

    Returns:
        A 5-tuple of gradients, presumably aligned with the forward inputs.
        Slot 2 is the feature gradient. Slot 3 is the edge-weight gradient,
        which is ``None`` unless the edge weights require grad. All other
        slots are ``None``.

    Raises:
        RuntimeError: if the edge weights require grad but the forward pass
            did not save the features needed to compute that gradient.
    """
    saved = ctx.backward_csc
    # Drop ctx's reference immediately so saved tensors can be freed as soon
    # as this function is done with them.
    del ctx.backward_csc

    has_feat = len(saved) == 5
    if has_feat:
        rowptr, colind, quantized, edge_weight_csr, sym = saved
        q_input_shape = ctx.other_args
    else:
        rowptr, colind, edge_weight_csr, sym = saved
    del saved

    need_edge_grad = edge_weight_csr is not None and edge_weight_csr.requires_grad
    if need_edge_grad and not has_feat:
        # Without saved features there is nothing to feed the SDDMM kernel;
        # fail with a clear message instead of an undefined-variable error.
        raise RuntimeError(
            "backward: edge weights require grad but the forward pass did not "
            "save the input features needed to compute their gradient"
        )

    # The weighted path always forced a contiguous gradient; apply it to both
    # paths for consistency (returns the same tensor if already contiguous).
    grad_out = grad_out.contiguous()

    # Feature gradient: multiply grad_out by the transposed sparse matrix.
    # A truthy ``sym`` lets the CSR arrays stand in for the CSC ones; any
    # falsy value (False, None, 0, ...) takes the always-correct transpose.
    if edge_weight_csr is not None:
        if sym:
            colptr, rowind, edge_weight_csc = rowptr, colind, edge_weight_csr
        else:
            colptr, rowind, edge_weight_csc = spmm.csr2csc(rowptr, colind, edge_weight_csr)
        grad_feat = spmm.csr_spmm(colptr, rowind, edge_weight_csc, grad_out)
    elif sym:
        grad_feat = spmm.csr_spmm_no_edge_value(rowptr, colind, grad_out)
    else:
        colptr, rowind, _ = spmm.csr2csc(rowptr, colind, edge_weight_csr)
        grad_feat = spmm.csr_spmm_no_edge_value(colptr, rowind, grad_out)

    grad_edge_weight = None
    if need_edge_grad:
        # Dequantize lazily: the features are only used for the edge-weight
        # gradient, so skip the work and the full-size allocation otherwise.
        feat = dequantize_activation(quantized, q_input_shape)
        del quantized
        grad_edge_weight = sddmm.csr_sddmm(rowptr, colind, grad_out, feat)

    return None, None, grad_feat, grad_edge_weight, None
nothing calls this directly
no test coverage detected