MCPcopy
hub / github.com/dmlc/dgl / test_all_binary_builtins

Function test_all_binary_builtins

tests/python/common/test_heterograph-kernel.py:226–397  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

224
225
226def test_all_binary_builtins():
227 def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast="none"):
228 # initialize node/edge features with uniform(-1, 1)
229 hu, hv, he = generate_feature(g, broadcast, binary_op)
230 if binary_op == "div":
231 # op = div
232 # lhs range: [-1, 1]
233 # rhs range: [1, 2]
234 # result range: [-1, 1]
235 if rhs == "u":
236 hu = (hu + 3) / 2
237 elif rhs == "v":
238 hv = (hv + 3) / 2
239 elif rhs == "e":
240 he = (he + 3) / 2
241
242 if binary_op == "add" or binary_op == "sub":
243 # op = add, sub
244 # lhs range: [-1/2, 1/2]
245 # rhs range: [-1/2, 1/2]
246 # result range: [-1, 1]
247 hu = hu / 2
248 hv = hv / 2
249 he = he / 2
250
251 g.ndata["u"] = F.attach_grad(F.clone(hu))
252 g.ndata["v"] = F.attach_grad(F.clone(hv))
253 g.edata["e"] = F.attach_grad(F.clone(he))
254
255 builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
256 builtin_msg = getattr(fn, builtin_msg_name)
257 builtin_red = getattr(fn, reducer)
258
259 def target_feature_switch(g, target):
260 if target == "u":
261 return g.ndata["u"]
262 elif target == "v":
263 return g.ndata["v"]
264 else:
265 return g.edata["e"]
266
267 with F.record_grad():
268 if partial:
269 g.pull(nid, builtin_msg(lhs, rhs, "m"), builtin_red("m", "r1"))
270 else:
271 g.update_all(builtin_msg(lhs, rhs, "m"), builtin_red("m", "r1"))
272 r1 = g.ndata.pop("r1")
273 F.backward(F.reduce_sum(r1))
274 lhs_grad_1 = F.grad(target_feature_switch(g, lhs))
275 rhs_grad_1 = F.grad(target_feature_switch(g, rhs))
276
277 # reset grad
278 g.ndata["u"] = F.attach_grad(F.clone(hu))
279 g.ndata["v"] = F.attach_grad(F.clone(hv))
280 g.edata["e"] = F.attach_grad(F.clone(he))
281
282 def target_switch(edges, target):
283 if target == "u":

Callers 1

Calls 7

_test (Function) · 0.70
graph (Method) · 0.45
add_nodes (Method) · 0.45
add_edges (Method) · 0.45
nodes (Method) · 0.45
to (Method) · 0.45
ctx (Method) · 0.45

Tested by

no test coverage detected