(mfunc)
@parametrize_idtype
def test_unary_copy_u(idtype):
    """Check that per-relation and all-relation ``apply_edges`` agree.

    The same heterograph is processed twice with the unary message function:
    once by calling ``apply_edges`` separately for every canonical edge type,
    and once by a single ``apply_edges`` call over all edge types. The
    resulting "plays" edge feature ``m`` and the gradient it induces on the
    "user" node feature ``h`` must match between the two passes.
    """

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))

        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        # apply_edges() is called on each relation type separately
        #################################################################

        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # Snapshot the gradient. Some backends (e.g. PyTorch) accumulate
            # later backward passes into the same grad tensor in place,
            # which would make n_grad1 alias n_grad2 and turn the final
            # gradient comparison into a tautology.
            n_grad1 = F.clone(F.grad(g.nodes["user"].data["h"]))

        # On a heterograph, g.edata["m"] builds a new {etype: tensor} dict,
        # so calling .clear() on it never touched the graph. Delete the
        # feature from every relation so the second pass must recompute it.
        for rel in g.canonical_etypes:
            del g.edges[rel].data["m"]

        # Reset the gradient of x1 before the second backward pass.
        # NOTE(review): relies on the test backend's attach_grad zeroing an
        # existing grad (PyTorch) or attaching a fresh buffer (MXNet);
        # confirm for any other backend this test runs on.
        F.attach_grad(x1)

        #################################################################
        # apply_edges() is called on all relation types
        #################################################################

        with F.record_grad():
            g.apply_edges(mfunc("h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            n_grad2 = F.grad(g.nodes["user"].data["h"])

        # correctness check
        def _print_error(a, b):
            """Print every flattened position where a and b differ."""
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)
| 123 |
no test coverage detected