(mfunc, rfunc)
| 90 | @parametrize_idtype |
| 91 | def test_unary_copy_u(idtype): |
| 92 | def _test(mfunc, rfunc): |
| 93 | g = create_test_heterograph_2(idtype) |
| 94 | g0 = create_test_heterograph(idtype) |
| 95 | g1 = create_test_heterograph_large(idtype) |
| 96 | cross_reducer = rfunc.__name__ |
| 97 | x1 = F.randn((g.num_nodes("user"), feat_size)) |
| 98 | x2 = F.randn((g.num_nodes("developer"), feat_size)) |
| 99 | F.attach_grad(x1) |
| 100 | F.attach_grad(x2) |
| 101 | g.nodes["user"].data["h"] = x1 |
| 102 | g.nodes["developer"].data["h"] = x2 |
| 103 | |
| 104 | ################################################################# |
| 105 | # multi_update_all(): call msg_passing separately for each etype |
| 106 | ################################################################# |
| 107 | |
| 108 | with F.record_grad(): |
| 109 | g.multi_update_all( |
| 110 | { |
| 111 | etype: (mfunc("h", "m"), rfunc("m", "y")) |
| 112 | for etype in g.canonical_etypes |
| 113 | }, |
| 114 | cross_reducer, |
| 115 | ) |
| 116 | r1 = g.nodes["game"].data["y"].clone() |
| 117 | r2 = g.nodes["user"].data["y"].clone() |
| 118 | r3 = g.nodes["player"].data["y"].clone() |
| 119 | loss = r1.sum() + r2.sum() + r3.sum() |
| 120 | F.backward(loss) |
| 121 | n_grad1 = F.grad(g.nodes["user"].data["h"]).clone() |
| 122 | n_grad2 = F.grad(g.nodes["developer"].data["h"]).clone() |
| 123 | |
| 124 | g.nodes["user"].data.clear() |
| 125 | g.nodes["developer"].data.clear() |
| 126 | g.nodes["game"].data.clear() |
| 127 | g.nodes["player"].data.clear() |
| 128 | |
| 129 | ################################################################# |
| 130 | # update_all(): call msg_passing for all etypes |
| 131 | ################################################################# |
| 132 | |
| 133 | F.attach_grad(x1) |
| 134 | F.attach_grad(x2) |
| 135 | g.nodes["user"].data["h"] = x1 |
| 136 | g.nodes["developer"].data["h"] = x2 |
| 137 | |
| 138 | with F.record_grad(): |
| 139 | g.update_all(mfunc("h", "m"), rfunc("m", "y")) |
| 140 | r4 = g.nodes["game"].data["y"] |
| 141 | r5 = g.nodes["user"].data["y"] |
| 142 | r6 = g.nodes["player"].data["y"] |
| 143 | loss = r4.sum() + r5.sum() + r6.sum() |
| 144 | F.backward(loss) |
| 145 | n_grad3 = F.grad(g.nodes["user"].data["h"]) |
| 146 | n_grad4 = F.grad(g.nodes["developer"].data["h"]) |
| 147 | |
| 148 | assert F.allclose(r1, r4) |
| 149 | assert F.allclose(r2, r5) |
no test coverage detected