MCPcopy
hub / github.com/dmlc/dgl / test_binary_op

Function test_binary_op

tests/python/common/test_heterograph-update-all.py:252–334  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

250
251@parametrize_idtype
252def test_binary_op(idtype):
253 def _test(lhs, rhs, binary_op, reducer):
254 g = create_test_heterograph(idtype)
255
256 x1 = F.randn((g.num_nodes("user"), feat_size))
257 x2 = F.randn((g.num_nodes("developer"), feat_size))
258 x3 = F.randn((g.num_nodes("game"), feat_size))
259
260 F.attach_grad(x1)
261 F.attach_grad(x2)
262 F.attach_grad(x3)
263 g.nodes["user"].data["h"] = x1
264 g.nodes["developer"].data["h"] = x2
265 g.nodes["game"].data["h"] = x3
266
267 x1 = F.randn((4, feat_size))
268 x2 = F.randn((4, feat_size))
269 x3 = F.randn((3, feat_size))
270 x4 = F.randn((3, feat_size))
271 F.attach_grad(x1)
272 F.attach_grad(x2)
273 F.attach_grad(x3)
274 F.attach_grad(x4)
275 g["plays"].edata["h"] = x1
276 g["follows"].edata["h"] = x2
277 g["develops"].edata["h"] = x3
278 g["wishes"].edata["h"] = x4
279
280 builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
281 builtin_msg = getattr(fn, builtin_msg_name)
282 builtin_red = getattr(fn, reducer)
283
284 #################################################################
285 # multi_update_all(): call msg_passing separately for each etype
286 #################################################################
287
288 with F.record_grad():
289 g.multi_update_all(
290 {
291 etype: (builtin_msg("h", "h", "m"), builtin_red("m", "y"))
292 for etype in g.canonical_etypes
293 },
294 "sum",
295 )
296 r1 = g.nodes["game"].data["y"]
297 F.backward(r1, F.ones(r1.shape))
298 n_grad1 = F.grad(r1)
299
300 #################################################################
301 # update_all(): call msg_passing for all etypes
302 #################################################################
303
304 g.update_all(builtin_msg("h", "h", "m"), builtin_red("m", "y"))
305 r2 = g.nodes["game"].data["y"]
306 F.backward(r2, F.ones(r2.shape))
307 n_grad2 = F.grad(r2)
308
309 # correctness check

Callers 1

Calls 1

_test · Function · 0.70

Tested by

no test coverage detected