| 224 | |
| 225 | |
| 226 | def test_all_binary_builtins(): |
| 227 | def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast="none"): |
| 228 | # initialize node/edge features with uniform(-1, 1) |
| 229 | hu, hv, he = generate_feature(g, broadcast, binary_op) |
| 230 | if binary_op == "div": |
| 231 | # op = div |
| 232 | # lhs range: [-1, 1] |
| 233 | # rhs range: [1, 2] |
| 234 | # result range: [-1, 1] |
| 235 | if rhs == "u": |
| 236 | hu = (hu + 3) / 2 |
| 237 | elif rhs == "v": |
| 238 | hv = (hv + 3) / 2 |
| 239 | elif rhs == "e": |
| 240 | he = (he + 3) / 2 |
| 241 | |
| 242 | if binary_op == "add" or binary_op == "sub": |
| 243 | # op = add, sub |
| 244 | # lhs range: [-1/2, 1/2] |
| 245 | # rhs range: [-1/2, 1/2] |
| 246 | # result range: [-1, 1] |
| 247 | hu = hu / 2 |
| 248 | hv = hv / 2 |
| 249 | he = he / 2 |
| 250 | |
| 251 | g.ndata["u"] = F.attach_grad(F.clone(hu)) |
| 252 | g.ndata["v"] = F.attach_grad(F.clone(hv)) |
| 253 | g.edata["e"] = F.attach_grad(F.clone(he)) |
| 254 | |
| 255 | builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs) |
| 256 | builtin_msg = getattr(fn, builtin_msg_name) |
| 257 | builtin_red = getattr(fn, reducer) |
| 258 | |
def target_feature_switch(g, target):
    """Return the feature stored on ``g`` for the operand code ``target``.

    ``"u"`` and ``"v"`` select the node features of the same name from
    ``g.ndata``; any other value selects the edge feature ``"e"`` from
    ``g.edata``. Only the selected container is accessed.
    """
    # Anything that is not a node operand falls back to the edge feature.
    if target not in ("u", "v"):
        return g.edata["e"]
    return g.ndata["u"] if target == "u" else g.ndata["v"]
| 266 | |
| 267 | with F.record_grad(): |
| 268 | if partial: |
| 269 | g.pull(nid, builtin_msg(lhs, rhs, "m"), builtin_red("m", "r1")) |
| 270 | else: |
| 271 | g.update_all(builtin_msg(lhs, rhs, "m"), builtin_red("m", "r1")) |
| 272 | r1 = g.ndata.pop("r1") |
| 273 | F.backward(F.reduce_sum(r1)) |
| 274 | lhs_grad_1 = F.grad(target_feature_switch(g, lhs)) |
| 275 | rhs_grad_1 = F.grad(target_feature_switch(g, rhs)) |
| 276 | |
| 277 | # reset grad |
| 278 | g.ndata["u"] = F.attach_grad(F.clone(hu)) |
| 279 | g.ndata["v"] = F.attach_grad(F.clone(hv)) |
| 280 | g.edata["e"] = F.attach_grad(F.clone(he)) |
| 281 | |
| 282 | def target_switch(edges, target): |
| 283 | if target == "u": |