(dZ)
| 140 | out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y) |
| 141 | |
def grad(dZ):
    """Return the gradients ``(dX, dY)`` for the enclosing ``_gspmm`` call.

    This is a closure: it reads ``gidx``, ``op``, ``reduce_op``, ``X``, ``Y``,
    ``argX`` and ``argY`` from the enclosing function.

    Parameters
    ----------
    dZ
        Incoming gradient (presumably w.r.t. ``out`` of the enclosing
        function -- confirm against the decorator on the enclosing def).
        It is converted with ``tensor`` before use.

    Returns
    -------
    tuple
        ``(dX, dY)``. ``dX`` is ``tf.zeros_like(X)`` when ``op`` is
        ``"copy_rhs"``; ``dY`` is ``tf.zeros_like(Y)`` when ``op`` is
        ``"copy_lhs"``.

    Notes
    -----
    Only the ops ``mul``, ``div``, ``add``, ``sub``, ``copy_lhs`` and
    ``copy_rhs`` are handled; any other value leaves ``dX``/``dY``
    unassigned and fails at the following ``_reduce_grad`` call.
    """
    dZ = tensor(dZ)

    # ---- gradient w.r.t. X (X does not contribute when op == "copy_rhs") ----
    if op != "copy_rhs":
        g_rev = gidx.reverse()
        if reduce_op == "sum":
            if op in ["mul", "div"]:
                # NOTE(review): _muldiv presumably yields Y for "mul" and 1/Y
                # for "div" -- confirm against its definition.
                dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
            elif op in ["add", "sub"]:
                dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
            elif op == "copy_lhs":
                dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
        else:
            # Non-"sum" reducers use the arg indices returned by the forward
            # call (presumably the winning entries of a max/min reduction --
            # TODO confirm) to route dZ back to X.
            if op in ["mul", "div"]:
                dX = _scatter_nd(
                    argX,
                    _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
                    * dZ,
                    X.shape[0],
                )
            elif op in ["add", "sub", "copy_lhs"]:
                dX = _scatter_nd(argX, dZ, X.shape[0])
        # NOTE(review): presumably folds broadcast dimensions back so that dX
        # matches X.shape -- confirm against _reduce_grad.
        dX = _reduce_grad(dX, X.shape)
    else:
        dX = tf.zeros_like(X)

    # ---- gradient w.r.t. Y (Y does not contribute when op == "copy_lhs") ----
    if op != "copy_lhs":
        if reduce_op == "sum":
            if op == "mul" and _need_reduce_last_dim(X, Y):
                dY = _gsddmm(gidx, "dot", X, dZ)
            elif op in ["mul", "div"]:
                dY = _gsddmm(gidx, "mul", X, dZ)
                if op == "div":
                    # d(x / y)/dy = -x / y**2
                    dY = -dY / (Y**2)
            elif op in ["add", "sub", "copy_rhs"]:
                # NOTE(review): _addsub presumably negates dZ for "sub" --
                # confirm against its definition.
                dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
        else:
            # The previously computed but never used ``out_shp`` local was
            # removed here; it had no effect on the result.
            if op in ["mul", "div"]:
                dY = _scatter_nd(
                    argY,
                    _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
                    Y.shape[0],
                )
                if op == "div":
                    # d(x / y)/dy = -x / y**2
                    dY = -dY / (Y**2)
            elif op in ["add", "sub", "copy_rhs"]:
                dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
        dY = _reduce_grad(dY, Y.shape)
    else:
        dY = tf.zeros_like(Y)

    return dX, dY
| 192 | |
| 193 | return out, grad |
| 194 |
nothing calls this directly
no test coverage detected