(idtype)
| 250 | |
| 251 | @parametrize_idtype |
| 252 | def test_binary_op(idtype): |
| 253 | def _test(lhs, rhs, binary_op, reducer): |
| 254 | g = create_test_heterograph(idtype) |
| 255 | |
| 256 | x1 = F.randn((g.num_nodes("user"), feat_size)) |
| 257 | x2 = F.randn((g.num_nodes("developer"), feat_size)) |
| 258 | x3 = F.randn((g.num_nodes("game"), feat_size)) |
| 259 | |
| 260 | F.attach_grad(x1) |
| 261 | F.attach_grad(x2) |
| 262 | F.attach_grad(x3) |
| 263 | g.nodes["user"].data["h"] = x1 |
| 264 | g.nodes["developer"].data["h"] = x2 |
| 265 | g.nodes["game"].data["h"] = x3 |
| 266 | |
| 267 | x1 = F.randn((4, feat_size)) |
| 268 | x2 = F.randn((4, feat_size)) |
| 269 | x3 = F.randn((3, feat_size)) |
| 270 | x4 = F.randn((3, feat_size)) |
| 271 | F.attach_grad(x1) |
| 272 | F.attach_grad(x2) |
| 273 | F.attach_grad(x3) |
| 274 | F.attach_grad(x4) |
| 275 | g["plays"].edata["h"] = x1 |
| 276 | g["follows"].edata["h"] = x2 |
| 277 | g["develops"].edata["h"] = x3 |
| 278 | g["wishes"].edata["h"] = x4 |
| 279 | |
| 280 | builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs) |
| 281 | builtin_msg = getattr(fn, builtin_msg_name) |
| 282 | builtin_red = getattr(fn, reducer) |
| 283 | |
| 284 | ################################################################# |
| 285 | # multi_update_all(): call msg_passing separately for each etype |
| 286 | ################################################################# |
| 287 | |
| 288 | with F.record_grad(): |
| 289 | g.multi_update_all( |
| 290 | { |
| 291 | etype: (builtin_msg("h", "h", "m"), builtin_red("m", "y")) |
| 292 | for etype in g.canonical_etypes |
| 293 | }, |
| 294 | "sum", |
| 295 | ) |
| 296 | r1 = g.nodes["game"].data["y"] |
| 297 | F.backward(r1, F.ones(r1.shape)) |
| 298 | n_grad1 = F.grad(r1) |
| 299 | |
| 300 | ################################################################# |
| 301 | # update_all(): call msg_passing for all etypes |
| 302 | ################################################################# |
| 303 | |
| 304 | g.update_all(builtin_msg("h", "h", "m"), builtin_red("m", "y")) |
| 305 | r2 = g.nodes["game"].data["y"] |
| 306 | F.backward(r2, F.ones(r2.shape)) |
| 307 | n_grad2 = F.grad(r2) |
| 308 | |
| 309 | # correctness check |
no test coverage detected