MCPcopy
hub / github.com/THUDM/CogDL / backward

Method backward

cogdl/operators/spmm.py:57–80  ·  view source on GitHub ↗
(ctx, grad_out)

Source from the content-addressed store, hash-verified

55
@staticmethod
def backward(ctx, grad_out):
    """Compute the gradients this operator hands back to autograd.

    Args:
        ctx: Context object whose ``backward_csc`` attribute holds either
            ``(rowptr, colind, feat, edge_weight_csr, sym)`` or
            ``(rowptr, colind, edge_weight_csr, sym)``.
        grad_out: Incoming gradient. It is made contiguous before being
            passed to the ``spmm`` / ``sddmm`` helpers.

    Returns:
        The 5-tuple ``(None, None, grad_feat, grad_edge_weight, None)``.
        ``grad_edge_weight`` is ``None`` unless ``edge_weight_csr`` is
        present and has ``requires_grad`` set.
    """
    saved = ctx.backward_csc
    if len(saved) == 5:
        rowptr, colind, feat, edge_weight_csr, sym = saved
    else:
        # NOTE(review): ``feat`` is absent from the 4-element form, so the
        # edge-weight gradient below cannot be computed from it. Presumably
        # forward saves the 5-element form whenever the edge weights
        # require grad -- confirm against forward.
        rowptr, colind, edge_weight_csr, sym = saved

    # Previously only the weighted path normalised the layout; the
    # unweighted path passes the same tensor to an ``spmm`` helper, so do
    # it once for both. ``contiguous()`` returns the tensor itself when it
    # is already contiguous.
    grad_out = grad_out.contiguous()

    if edge_weight_csr is None:
        # ``sym`` is tested by truthiness here exactly as in the weighted
        # path below (the old ``sym is False`` check treated 0 / None as
        # "symmetric" in this branch only).
        if sym:
            grad_feat = spmm.csr_spmm_no_edge_value(rowptr, colind, grad_out)
        else:
            # The third returned value (converted edge weights) is unused
            # because there are no edge weights on this path.
            colptr, rowind, _ = spmm.csr2csc(rowptr, colind, edge_weight_csr)
            grad_feat = spmm.csr_spmm_no_edge_value(colptr, rowind, grad_out)
        return None, None, grad_feat, None, None

    if sym:
        # Symmetric flag set: reuse the saved index arrays and weights
        # without converting them.
        colptr, rowind, edge_weight_csc = rowptr, colind, edge_weight_csr
    else:
        colptr, rowind, edge_weight_csc = spmm.csr2csc(rowptr, colind, edge_weight_csr)
    grad_feat = spmm.csr_spmm(colptr, rowind, edge_weight_csc, grad_out)

    grad_edge_weight = None
    if edge_weight_csr.requires_grad:
        grad_edge_weight = sddmm.csr_sddmm(rowptr, colind, grad_out, feat)
    return None, None, grad_feat, grad_edge_weight, None
81
82
83try:

Callers

nothing calls this directly

Calls 1

contiguousMethod · 0.80

Tested by

no test coverage detected