()
| 156 | |
| 157 | |
| 158 | def test_copy_edge_reduce(): |
| 159 | def _test(red, partial): |
| 160 | g = dgl.from_networkx(nx.erdos_renyi_graph(100, 0.1)) |
| 161 | # NOTE(zihao): add self-loop to avoid zero-degree nodes. |
| 162 | g.add_edges(g.nodes(), g.nodes()) |
| 163 | g = g.to(F.ctx()) |
| 164 | hu, hv, he = generate_feature(g, "none", "none") |
| 165 | if partial: |
| 166 | nid = F.tensor(list(range(0, 100, 2)), g.idtype) |
| 167 | |
| 168 | g.ndata["u"] = F.attach_grad(F.clone(hu)) |
| 169 | g.ndata["v"] = F.attach_grad(F.clone(hv)) |
| 170 | g.edata["e"] = F.attach_grad(F.clone(he)) |
| 171 | |
| 172 | with F.record_grad(): |
| 173 | if partial: |
| 174 | g.pull( |
| 175 | nid, |
| 176 | fn.copy_e(e="e", out="m"), |
| 177 | builtin[red](msg="m", out="r1"), |
| 178 | ) |
| 179 | else: |
| 180 | g.update_all( |
| 181 | fn.copy_e(e="e", out="m"), builtin[red](msg="m", out="r1") |
| 182 | ) |
| 183 | r1 = g.ndata["r1"] |
| 184 | F.backward(F.reduce_sum(r1)) |
| 185 | e_grad1 = F.grad(g.edata["e"]) |
| 186 | |
| 187 | # reset grad |
| 188 | g.ndata["u"] = F.attach_grad(F.clone(hu)) |
| 189 | g.ndata["v"] = F.attach_grad(F.clone(hv)) |
| 190 | g.edata["e"] = F.attach_grad(F.clone(he)) |
| 191 | |
| 192 | with F.record_grad(): |
| 193 | if partial: |
| 194 | g.pull(nid, udf_copy_edge, udf_reduce[red]) |
| 195 | else: |
| 196 | g.update_all(udf_copy_edge, udf_reduce[red]) |
| 197 | r2 = g.ndata["r2"] |
| 198 | F.backward(F.reduce_sum(r2)) |
| 199 | e_grad2 = F.grad(g.edata["e"]) |
| 200 | |
| 201 | def _print_error(a, b): |
| 202 | print("ERROR: Test copy_edge_{} partial: {}".format(red, partial)) |
| 203 | return |
| 204 | for i, (x, y) in enumerate( |
| 205 | zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten()) |
| 206 | ): |
| 207 | if not np.allclose(x, y): |
| 208 | print("@{} {} v.s. {}".format(i, x, y)) |
| 209 | |
| 210 | if not F.allclose(r1, r2): |
| 211 | _print_error(r1, r2) |
| 212 | assert F.allclose(r1, r2) |
| 213 | if not F.allclose(e_grad1, e_grad2): |
| 214 | print("edge gradient") |
| 215 | _print_error(e_grad1, e_grad2) |
no test coverage detected