(ctx, grad_out)
| 55 | |
@staticmethod
def backward(ctx, grad_out):
    """Backward pass of the sparse (CSR) x dense matmul autograd function.

    Reads the values stashed by ``forward`` in ``ctx.backward_csc``, which is
    either ``(rowptr, colind, feat, edge_weight_csr, sym)`` (5 entries) or
    ``(rowptr, colind, edge_weight_csr, sym)`` (4 entries, no ``feat``).
    NOTE(review): presumably ``forward`` omits ``feat`` only when the edge
    weights do not need a gradient -- confirm against ``forward``.

    When ``sym`` is truthy the sparse pattern is treated as symmetric and the
    CSR arrays are reused directly. Otherwise they are converted with
    ``spmm.csr2csc`` before the feature-gradient SpMM.

    Args:
        ctx: autograd context populated by ``forward``.
        grad_out: gradient of the loss w.r.t. the output of ``forward``.

    Returns:
        A 5-tuple ``(None, None, grad_feat, grad_edge_weight, None)``, one
        entry per ``forward`` input. ``grad_edge_weight`` is ``None`` when
        there are no edge weights or they do not require a gradient.

    Raises:
        RuntimeError: if the edge weights require a gradient but ``forward``
            did not save ``feat`` (4-entry ``ctx.backward_csc``).
    """
    saved = ctx.backward_csc
    if len(saved) == 5:
        rowptr, colind, feat, edge_weight_csr, sym = saved
    else:
        rowptr, colind, edge_weight_csr, sym = saved
        # Bind explicitly so a missing ``feat`` produces a clear error below
        # instead of an UnboundLocalError.
        feat = None

    # Both code paths hand grad_out to the sparse kernels, so give them the
    # same contiguous layout (a no-op if grad_out is already contiguous).
    grad_out = grad_out.contiguous()

    if edge_weight_csr is None:
        # Unweighted pattern: only the feature gradient is produced.
        # Use truthiness of ``sym``, matching the weighted path below.
        if sym:
            grad_feat = spmm.csr_spmm_no_edge_value(rowptr, colind, grad_out)
        else:
            # Only the transposed index arrays are needed; there are no
            # edge values to permute.
            # NOTE(review): relies on csr2csc accepting None values.
            colptr, rowind, _ = spmm.csr2csc(rowptr, colind, edge_weight_csr)
            grad_feat = spmm.csr_spmm_no_edge_value(colptr, rowind, grad_out)
        return None, None, grad_feat, None, None

    if sym:
        colptr, rowind, edge_weight_csc = rowptr, colind, edge_weight_csr
    else:
        colptr, rowind, edge_weight_csc = spmm.csr2csc(
            rowptr, colind, edge_weight_csr
        )
    grad_feat = spmm.csr_spmm(colptr, rowind, edge_weight_csc, grad_out)

    grad_edge_weight = None
    if edge_weight_csr.requires_grad:
        if feat is None:
            raise RuntimeError(
                "backward: edge weights require grad but forward did not "
                "save feat in ctx.backward_csc"
            )
        # Presumably samples grad_out @ feat^T at the nonzeros of the CSR
        # pattern, i.e. the gradient of each edge weight -- confirm kernel.
        grad_edge_weight = sddmm.csr_sddmm(rowptr, colind, grad_out, feat)
    return None, None, grad_feat, grad_edge_weight, None
| 81 | |
| 82 | |
| 83 | try: |
nothing calls this directly
no test coverage detected