Repository: github.com/dmlc/dgl

Function test_batch_setter_autograd

tests/python/common/function/test_basics.py:183–197
Signature: test_batch_setter_autograd(idtype)


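```python
@parametrize_idtype
def test_batch_setter_autograd(idtype):
    # F, D, generate_graph and parametrize_idtype are defined or imported at
    # module level in test_basics.py (F is DGL's backend-agnostic test helper).
    g = generate_graph(idtype, grad=True)  # 10-node graph whose ndata["h"] tracks gradients
    h1 = g.ndata["h"]
    # Partial set: overwrite the features of nodes 1, 2 and 8 with hh.
    v = F.tensor([1, 2, 8], g.idtype)
    hh = F.attach_grad(F.zeros((len(v), D)))
    with F.record_grad():
        g.nodes[v].data["h"] = hh
        h2 = g.ndata["h"]
    # Backpropagate an upstream gradient of 2 through every element of h2.
    F.backward(h2, F.ones((10, D)) * 2)
    # Overwritten rows (1, 2, 8) no longer depend on h1, so their gradient is 0.
    assert F.array_equal(
        F.grad(h1)[:, 0],
        F.tensor([2.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0]),
    )
    # Every row of hh was written into h2, so each receives the full gradient.
    assert F.array_equal(F.grad(hh)[:, 0], F.tensor([2.0, 2.0, 2.0]))
```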
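The expected gradients follow from treating the partial set as an out-of-place row replacement. The sketch below reproduces that arithmetic in plain PyTorch rather than through DGL; it assumes the setter behaves like `torch.Tensor.index_copy` (a stand-in chosen for illustration, not DGL's actual implementation), and `D = 5` and the random initial features are arbitrary choices for the sketch.

```python
import torch

N, D = 10, 5  # 10 nodes as in the test; D = 5 is an arbitrary width for this sketch

# Stand-in for g.ndata["h"]: node features that require gradients.
h1 = torch.randn(N, D, requires_grad=True)

# Stand-ins for v and hh: the rows to overwrite and their new values.
v = torch.tensor([1, 2, 8])
hh = torch.zeros(len(v), D, requires_grad=True)

# Out-of-place partial set: rows in v come from hh, all other rows from h1.
h2 = h1.index_copy(0, v, hh)

# Backpropagate an upstream gradient of 2 into every element of h2.
h2.backward(torch.ones(N, D) * 2)

print(h1.grad[:, 0])  # 2.0 everywhere except rows 1, 2 and 8, which get 0.0
print(hh.grad[:, 0])  # 2.0 for each overwritten row
```

Rows listed in `v` take their values from `hh`, so the upstream gradient flows to `hh` and is masked to zero for those rows of `h1`; every other row passes the gradient straight through to `h1`, which is the pattern the test's first assertion checks.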

Callers

None directly; pytest collects and runs this test.

Calls (3)

grad (method) · 0.80
generate_graph (function) · 0.70
backward (method) · 0.45

Tested by

No test coverage detected (the function is itself a test).