| 193 | |
@staticmethod
def backward(ctx, dZ):
    """Back-propagate ``dZ`` through the generalized SpMM to its operands.

    Only the forward inputs at positions 3 (``X``) and 4 (``Y``) can receive
    gradients; the three leading inputs always get ``None``.

    Args:
        ctx: Autograd context. ``ctx.backward_cache`` holds the graph index,
            the binary ``op``, the ``reduce_op``, the shapes of ``X`` and
            ``Y``, the dtype/device used for freshly allocated gradients and
            the ``reduce_last`` flag. ``ctx.saved_tensors`` holds ``X``,
            ``Y`` and the arg-index tensors ``argX`` / ``argY``.
        dZ: Gradient of the loss w.r.t. the forward output.

    Returns:
        ``(None, None, None, dX, dY)``. ``dX`` is ``None`` when ``op`` is
        ``"copy_rhs"`` or X needs no gradient; ``dY`` is ``None`` when ``op``
        is ``"copy_lhs"`` or Y needs no gradient.

    Raises:
        ValueError: if ``op`` has no gradient formula here. ``X`` supports
            ``mul``, ``add`` and ``copy_lhs``; ``Y`` supports ``mul``, ``add``
            and ``copy_rhs``.
    """
    (
        gidx,
        op,
        reduce_op,
        X_shape,
        Y_shape,
        dtype,
        device,
        reduce_last,
    ) = ctx.backward_cache
    X, Y, argX, argY = ctx.saved_tensors

    # ---- Gradient w.r.t. X (forward input 3); copy_rhs never reads X. ----
    if op != "copy_rhs" and ctx.needs_input_grad[3]:
        if reduce_op == "sum":
            # Sum-reduction: send dZ back along the reversed edges with an
            # SpMM. The reversed index is only needed here, so it is not
            # built on the max/min path.
            g_rev = gidx.reverse()
            if op == "mul":
                dX = gspmm(g_rev, "mul", "sum", dZ, Y)
            elif op == "add":
                dX = gspmm(g_rev, "copy_lhs", "sum", dZ, Y)
            elif op == "copy_lhs":
                dX = gspmm(g_rev, "copy_lhs", "sum", dZ, None)
            else:
                # Fail loudly instead of hitting an UnboundLocalError below.
                raise ValueError(
                    f"GSpMM backward: unsupported op {op!r} for the X gradient"
                )
        else:  # max/min
            # Each dZ element is scattered only to the X row recorded in
            # argX (assumed: the row picked by the forward max/min -- TODO
            # confirm against the forward kernel).
            dX = th.zeros(
                (X_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
            )
            if op == "mul":
                grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
                dX.scatter_add_(0, argX.long(), grad)
            elif op in ["add", "copy_lhs"]:
                dX.scatter_add_(0, argX.long(), dZ)
            else:
                # Otherwise the all-zero buffer would be returned silently.
                raise ValueError(
                    f"GSpMM backward: unsupported op {op!r} for the X gradient"
                )
        # Presumably sums out broadcast dimensions so dX matches X_shape.
        dX = _reduce_grad(dX, X_shape)
    else:  # X has no gradient
        dX = None

    # ---- Gradient w.r.t. Y (forward input 4); copy_lhs never reads Y. ----
    if op != "copy_lhs" and ctx.needs_input_grad[4]:
        if reduce_op == "sum":
            if op == "mul" and reduce_last:
                # reduce_last presumably marks a forward that reduced over the
                # last feature dim, so the per-edge gradient is a dot product.
                dY = gsddmm(gidx, "dot", X, dZ)
            elif op == "mul":
                dY = gsddmm(gidx, "mul", X, dZ)
            elif op in ["add", "copy_rhs"]:
                dY = gsddmm(gidx, "copy_rhs", X, dZ)
            else:
                raise ValueError(
                    f"GSpMM backward: unsupported op {op!r} for the Y gradient"
                )
        else:  # max/min
            # Mirror image of the X case: route dZ to the Y rows in argY.
            dY = th.zeros(
                (Y_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
            )
            if op == "mul":
                grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
                dY.scatter_add_(0, argY.long(), grad)
            elif op in ["add", "copy_rhs"]:
                dY.scatter_add_(0, argY.long(), dZ)
            else:
                raise ValueError(
                    f"GSpMM backward: unsupported op {op!r} for the Y gradient"
                )
        dY = _reduce_grad(dY, Y_shape)
    else:  # Y has no gradient
        dY = None
    return None, None, None, dX, dY
| 249 | |
| 250 | |
| 251 | class GSpMM_hetero(th.autograd.Function): |