MCPcopy
hub / github.com/dmlc/dgl / test_sddmm

Function test_sddmm

tests/python/common/ops/test_ops.py:234–299  ·  view source on GitHub ↗
(g, shp, lhs_target, rhs_target, msg, idtype)

Source from the content-addressed store, hash-verified

232)
233@parametrize_idtype
234def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
235 if lhs_target == rhs_target:
236 return
237 g = g.astype(idtype).to(F.ctx())
238 if dgl.backend.backend_name == "mxnet" and g.num_edges() == 0:
239 pytest.skip() # mxnet do not support zero shape tensor
240 print(g)
241 print(g.idtype)
242
243 len_lhs = select(
244 lhs_target,
245 g.number_of_src_nodes(),
246 g.num_edges(),
247 g.number_of_dst_nodes(),
248 )
249 lhs_shp = (len_lhs,) + shp[0]
250 len_rhs = select(
251 rhs_target,
252 g.number_of_src_nodes(),
253 g.num_edges(),
254 g.number_of_dst_nodes(),
255 )
256 rhs_shp = (len_rhs,) + shp[1]
257 feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
258 feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
259 print(
260 "lhs shape: {}, rhs shape: {}".format(
261 F.shape(feat_lhs), F.shape(feat_rhs)
262 )
263 )
264
265 lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)
266 rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)
267 lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs))
268 rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs))
269 msg_func = lhs_target + "_" + msg + "_" + rhs_target
270 print("SDDMM(message func: {})".format(msg_func))
271
272 lhs = F.attach_grad(F.clone(feat_lhs))
273 rhs = F.attach_grad(F.clone(feat_rhs))
274 with F.record_grad():
275 e = gsddmm(
276 g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target
277 )
278 F.backward(F.reduce_sum(e))
279 grad_lhs = F.grad(lhs)
280 grad_rhs = F.grad(rhs)
281
282 with F.record_grad():
283 g.apply_edges(udf_apply_edges[msg_func])
284 if g.num_edges() > 0:
285 e1 = g.edata["m"]
286 assert F.allclose(e, e1)
287 print("forward passed")
288
289 F.backward(F.reduce_sum(e1))
290 if msg != "copy_rhs":
291 assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs)

Callers

nothing calls this directly

Calls 14

gsddmm — Function · 0.90
select — Function · 0.85
number_of_src_nodes — Method · 0.80
number_of_dst_nodes — Method · 0.80
format — Method · 0.80
grad — Method · 0.80
to — Method · 0.45
astype — Method · 0.45
ctx — Method · 0.45
num_edges — Method · 0.45
shape — Method · 0.45
clone — Method · 0.45

Tested by

no test coverage detected