MCPcopy
hub / github.com/dmlc/dgl / backward

Method backward

python/dgl/backend/pytorch/sparse.py:195–248  ·  view source on GitHub ↗
(ctx, dZ)

Source from the content-addressed store, hash-verified

193
@staticmethod
def backward(ctx, dZ):
    """Backward pass of the generalized sparse-dense matmul (SpMM) function.

    Unpacks the metadata stored in ``ctx.backward_cache`` and the tensors in
    ``ctx.saved_tensors``. It then computes gradients for the two operand
    inputs, ``X`` and ``Y``.

    Parameters
    ----------
    ctx
        Autograd context. It must expose three attributes:

        - ``backward_cache``: the tuple ``(gidx, op, reduce_op, X_shape,
          Y_shape, dtype, device, reduce_last)``.
        - ``saved_tensors``: the tuple ``(X, Y, argX, argY)``.
        - ``needs_input_grad``: indices 3 and 4 correspond to ``X`` and
          ``Y``.
    dZ : torch.Tensor
        Gradient of the loss with respect to the forward output.

    Returns
    -------
    tuple
        ``(None, None, None, dX, dY)``, with one entry per forward input.

        - The first three inputs never receive a gradient.
        - ``dX`` is ``None`` when ``op == "copy_rhs"`` or ``X`` needs no
          gradient.
        - ``dY`` is ``None`` when ``op == "copy_lhs"`` or ``Y`` needs no
          gradient.
    """
    (
        gidx,
        op,
        reduce_op,
        X_shape,
        Y_shape,
        dtype,
        device,
        reduce_last,
    ) = ctx.backward_cache
    X, Y, argX, argY = ctx.saved_tensors

    # ---- Gradient w.r.t. X (forward input index 3) ----
    if op != "copy_rhs" and ctx.needs_input_grad[3]:
        # The sum-reduction gradient is another SpMM on the reversed graph.
        g_rev = gidx.reverse()
        if reduce_op == "sum":
            if op == "mul":
                dX = gspmm(g_rev, "mul", "sum", dZ, Y)
            elif op == "add":
                dX = gspmm(g_rev, "copy_lhs", "sum", dZ, Y)
            elif op == "copy_lhs":
                dX = gspmm(g_rev, "copy_lhs", "sum", dZ, None)
        else:  # max/min
            # Scatter dZ back into the rows of X selected by argX; every
            # other row stays zero.
            # NOTE(review): assumes argX holds the row index of X chosen by
            # the forward max/min for each output entry -- confirm.
            dX = th.zeros(
                (X_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
            )
            if op == "mul":
                grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
                dX.scatter_add_(0, argX.long(), grad)
            elif op in ["add", "copy_lhs"]:
                dX.scatter_add_(0, argX.long(), dZ)
        # Applies to both reduction kinds.
        # NOTE(review): presumably sums away broadcast dimensions so that
        # dX matches X_shape -- confirm against _reduce_grad.
        dX = _reduce_grad(dX, X_shape)
    else:  # X needs no gradient (unused by op, or not required)
        dX = None

    # ---- Gradient w.r.t. Y (forward input index 4) ----
    if op != "copy_lhs" and ctx.needs_input_grad[4]:
        if reduce_op == "sum":
            # NOTE(review): reduce_last presumably marks a forward "mul"
            # that was also reduced over the last feature dim, making the
            # gradient a per-edge dot product -- confirm.
            if op == "mul" and reduce_last:
                dY = gsddmm(gidx, "dot", X, dZ)
            elif op == "mul":
                dY = gsddmm(gidx, "mul", X, dZ)
            elif op in ["add", "copy_rhs"]:
                dY = gsddmm(gidx, "copy_rhs", X, dZ)
        else:  # max/min
            # Mirror of the X case: scatter into the rows of Y selected by
            # argY.
            dY = th.zeros(
                (Y_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
            )
            if op == "mul":
                grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
                dY.scatter_add_(0, argY.long(), grad)
            elif op in ["add", "copy_rhs"]:
                dY.scatter_add_(0, argY.long(), dZ)
        dY = _reduce_grad(dY, Y_shape)
    else:  # Y needs no gradient (unused by op, or not required)
        dY = None

    return None, None, None, dX, dY
249
250
251class GSpMM_hetero(th.autograd.Function):

Callers 1

backward (function) · 0.45

Calls 6

reverse (method) · 0.80
gspmm (function) · 0.70
_expand (function) · 0.70
_reduce_grad (function) · 0.70
gsddmm (function) · 0.70
long (method) · 0.45

Tested by

no test coverage detected