MCPcopy
hub / github.com/dmlc/dgl / grad

Function grad

python/dgl/backend/tensorflow/sparse.py:142–191  ·  view source on GitHub ↗
(dZ)

Source from the content-addressed store, hash-verified

140 out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
141
def grad(dZ):
    """Compute the gradients of the enclosing generalized SpMM.

    ``grad`` closes over the forward-pass inputs ``gidx``, ``op``,
    ``reduce_op``, ``X`` and ``Y``. It also closes over ``argX``/``argY``,
    the index tensors returned by the forward ``_gspmm`` call.

    Parameters
    ----------
    dZ : tensor-like
        Gradient of the loss with respect to the SpMM output. It is passed
        through ``tensor`` before use.

    Returns
    -------
    tuple
        ``(dX, dY)``, the gradients with respect to ``X`` and ``Y``. An
        operand that ``op`` does not read (``X`` for ``"copy_rhs"``, ``Y``
        for ``"copy_lhs"``) receives ``tf.zeros_like`` of itself.
    """
    dZ = tensor(dZ)

    # ---- Gradient w.r.t. X (lhs operand) ----
    if op != "copy_rhs":
        if reduce_op == "sum":
            # With sum reduction, dX is obtained by an SpMM of dZ over the
            # reversed graph. The reversed graph is only needed on this path,
            # so it is built here rather than for every op.
            g_rev = gidx.reverse()
            if op in ["mul", "div"]:
                # NOTE(review): _muldiv presumably yields Y for "mul" and 1/Y
                # for "div", i.e. the partial derivative w.r.t. x — confirm.
                dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
            elif op in ["add", "sub"]:
                dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
            elif op == "copy_lhs":
                dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
        else:
            # Non-sum reduction (presumably max/min): dZ is scattered back
            # only to the rows of X recorded in argX by the forward pass.
            if op in ["mul", "div"]:
                dX = _scatter_nd(
                    argX,
                    _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
                    * dZ,
                    X.shape[0],
                )
            elif op in ["add", "sub", "copy_lhs"]:
                dX = _scatter_nd(argX, dZ, X.shape[0])
        # NOTE(review): _reduce_grad presumably sums out broadcast dimensions
        # so that dX ends up with X's shape — confirm.
        dX = _reduce_grad(dX, X.shape)
    else:
        dX = tf.zeros_like(X)

    # ---- Gradient w.r.t. Y (rhs operand) ----
    if op != "copy_lhs":
        if reduce_op == "sum":
            if op == "mul" and _need_reduce_last_dim(X, Y):
                # NOTE(review): a "dot" SDDMM is used when the last feature
                # dimension has to be reduced — see _need_reduce_last_dim.
                dY = _gsddmm(gidx, "dot", X, dZ)
            elif op in ["mul", "div"]:
                dY = _gsddmm(gidx, "mul", X, dZ)
                if op == "div":
                    # d(x / y)/dy = -x / y**2
                    dY = -dY / (Y**2)
            elif op in ["add", "sub", "copy_rhs"]:
                # NOTE(review): _addsub presumably negates dZ for "sub".
                dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
        else:
            # Non-sum reduction: dZ is scattered back only to the rows of Y
            # recorded in argY; _scatter_nd needs just the row count.
            if op in ["mul", "div"]:
                dY = _scatter_nd(
                    argY,
                    _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
                    Y.shape[0],
                )
                if op == "div":
                    # d(x / y)/dy = -x / y**2
                    dY = -dY / (Y**2)
            elif op in ["add", "sub", "copy_rhs"]:
                dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
        dY = _reduce_grad(dY, Y.shape)
    else:
        dY = tf.zeros_like(Y)

    return dX, dY
192
193 return out, grad
194

Callers

nothing calls this directly

Calls 14

_gspmm (function) · 0.85
_gsddmm (function) · 0.85
_csrmm (function) · 0.85
_csrmask (function) · 0.85
reverse (method) · 0.80
number_of_ntypes (method) · 0.80
tensor (function) · 0.70
_muldiv (function) · 0.70
_scatter_nd (function) · 0.70
_gather_nd (function) · 0.70
_expand (function) · 0.70
_reduce_grad (function) · 0.70

Tested by

no test coverage detected