(red, partial)
| 89 | |
| 90 | def test_copy_src_reduce(): |
| 91 | def _test(red, partial):  # compare builtin fn.copy_u + builtin[red] against udf_copy_src + udf_reduce[red]: reduced values and grads w.r.t. "u" must match |
| 92 | g = dgl.from_networkx(nx.erdos_renyi_graph(100, 0.1))  # random 100-node graph, edge probability 0.1 |
| 93 | # NOTE(zihao): add self-loop to avoid zero-degree nodes. |
| 94 | # https://github.com/dmlc/dgl/issues/761 |
| 95 | g.add_edges(g.nodes(), g.nodes()) |
| 96 | g = g.to(F.ctx()) |
| 97 | hu, hv, he = generate_feature(g, "none", "none")  # hu/hv become ndata "u"/"v", he becomes edata "e" (generate_feature is defined outside this view) |
| 98 | if partial: |
| 99 | nid = F.tensor(list(range(0, 100, 2)), g.idtype)  # even node ids 0, 2, ..., 98: the only nodes reduced via pull() |
| 100 | |
| 101 | g.ndata["u"] = F.attach_grad(F.clone(hu))  # fresh grad-tracked copies for the builtin pass |
| 102 | g.ndata["v"] = F.attach_grad(F.clone(hv)) |
| 103 | g.edata["e"] = F.attach_grad(F.clone(he)) |
| 104 | |
| 105 | with F.record_grad(): |
| 106 | if partial: |
| 107 | g.pull(  # reduce only into the nodes in nid |
| 108 | nid, |
| 109 | fn.copy_u(u="u", out="m"), |
| 110 | builtin[red](msg="m", out="r1"), |
| 111 | ) |
| 112 | else: |
| 113 | g.update_all(  # reduce into every node |
| 114 | fn.copy_u(u="u", out="m"), builtin[red](msg="m", out="r1") |
| 115 | ) |
| 116 | r1 = g.ndata["r1"] |
| 117 | F.backward(F.reduce_sum(r1))  # backprop the scalar sum of the reduced feature |
| 118 | n_grad1 = F.grad(g.ndata["u"])  # only the grad w.r.t. "u" is compared; "v"/"e" grads are not checked |
| 119 | |
| 120 | # reset grad: rebind fresh grad-tracked copies of the same features for the UDF pass |
| 121 | g.ndata["u"] = F.attach_grad(F.clone(hu)) |
| 122 | g.ndata["v"] = F.attach_grad(F.clone(hv)) |
| 123 | g.edata["e"] = F.attach_grad(F.clone(he)) |
| 124 | |
| 125 | with F.record_grad(): |
| 126 | if partial: |
| 127 | g.pull(nid, udf_copy_src, udf_reduce[red])  # UDF counterparts, defined outside this view |
| 128 | else: |
| 129 | g.update_all(udf_copy_src, udf_reduce[red]) |
| 130 | r2 = g.ndata["r2"]  # assumes udf_reduce[red] writes its result to "r2" -- defined elsewhere, confirm |
| 131 | F.backward(F.reduce_sum(r2)) |
| 132 | n_grad2 = F.grad(g.ndata["u"]) |
| 133 | |
| 134 | def _print_error(a, b):  # debug helper: print every flattened element where a and b are not allclose |
| 135 | print("ERROR: Test copy_src_{} partial: {}".format(red, partial)) |
| 136 | for i, (x, y) in enumerate( |
| 137 | zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten()) |
| 138 | ): |
| 139 | if not np.allclose(x, y): |
| 140 | print("@{} {} v.s. {}".format(i, x, y)) |
| 141 | |
| 142 | if not F.allclose(r1, r2):  # dump mismatches before asserting so a failure is diagnosable |
| 143 | _print_error(r1, r2) |
| 144 | assert F.allclose(r1, r2) |
| 145 | if not F.allclose(n_grad1, n_grad2): |
| 146 | print("node grad") |
| 147 | _print_error(n_grad1, n_grad2) |
| 148 | assert F.allclose(n_grad1, n_grad2) |
no test coverage detected