MCP · copy · Index your code
hub / github.com/dmlc/dgl / test_copy_edge_reduce

Function test_copy_edge_reduce

tests/python/common/test_heterograph-kernel.py:158–223  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

156
157
158def test_copy_edge_reduce():
159 def _test(red, partial):
160 g = dgl.from_networkx(nx.erdos_renyi_graph(100, 0.1))
161 # NOTE(zihao): add self-loop to avoid zero-degree nodes.
162 g.add_edges(g.nodes(), g.nodes())
163 g = g.to(F.ctx())
164 hu, hv, he = generate_feature(g, "none", "none")
165 if partial:
166 nid = F.tensor(list(range(0, 100, 2)), g.idtype)
167
168 g.ndata["u"] = F.attach_grad(F.clone(hu))
169 g.ndata["v"] = F.attach_grad(F.clone(hv))
170 g.edata["e"] = F.attach_grad(F.clone(he))
171
172 with F.record_grad():
173 if partial:
174 g.pull(
175 nid,
176 fn.copy_e(e="e", out="m"),
177 builtin[red](msg="m", out="r1"),
178 )
179 else:
180 g.update_all(
181 fn.copy_e(e="e", out="m"), builtin[red](msg="m", out="r1")
182 )
183 r1 = g.ndata["r1"]
184 F.backward(F.reduce_sum(r1))
185 e_grad1 = F.grad(g.edata["e"])
186
187 # reset grad
188 g.ndata["u"] = F.attach_grad(F.clone(hu))
189 g.ndata["v"] = F.attach_grad(F.clone(hv))
190 g.edata["e"] = F.attach_grad(F.clone(he))
191
192 with F.record_grad():
193 if partial:
194 g.pull(nid, udf_copy_edge, udf_reduce[red])
195 else:
196 g.update_all(udf_copy_edge, udf_reduce[red])
197 r2 = g.ndata["r2"]
198 F.backward(F.reduce_sum(r2))
199 e_grad2 = F.grad(g.edata["e"])
200
def _print_error(a, b, *, max_report=10):
    """Print a diagnostic for a failed kernel-vs-UDF comparison.

    Nested helper of ``_test``: the header line reads ``red`` and
    ``partial`` from the enclosing scope.

    Args:
        a, b: backend tensors, converted with ``F.asnumpy`` and flattened
            before comparison. Elements are compared pairwise up to the
            shorter length, matching the original ``zip`` behaviour.
        max_report: maximum number of mismatching positions to print.
    """
    print("ERROR: Test copy_edge_{} partial: {}".format(red, partial))
    # The previous version returned here unconditionally, which left the
    # per-element report below unreachable. Instead of dumping every
    # element, print only the first `max_report` mismatches.
    x = F.asnumpy(a).flatten()
    y = F.asnumpy(b).flatten()
    n = min(x.size, y.size)
    x, y = x[:n], y[:n]
    # np.isclose uses the same default tolerances as the per-element
    # np.allclose(x, y) check it replaces.
    mismatch = np.flatnonzero(~np.isclose(x, y))
    print("{} of {} elements differ".format(mismatch.size, n))
    for i in mismatch[:max_report]:
        print("@{} {} v.s. {}".format(i, x[i], y[i]))
209
210 if not F.allclose(r1, r2):
211 _print_error(r1, r2)
212 assert F.allclose(r1, r2)
213 if not F.allclose(e_grad1, e_grad2):
214 print("edge gradient")
215 _print_error(e_grad1, e_grad2)

Callers 1

Calls 1

_test · Function · 0.70

Tested by

no test coverage detected