Repository: github.com/dmlc/dgl

Function test_backward(idtype)

Defined in tests/python/common/test_heterograph.py, lines 2048–2073.


```python
# F (DGL's test backend shim), fn (dgl.function), parametrize_idtype and
# create_test_heterograph come from the surrounding test module.
@parametrize_idtype
def test_backward(idtype):
    g = create_test_heterograph(idtype)
    x = F.randn((3, 5))  # features for the 3 "user" nodes
    F.attach_grad(x)
    g.nodes["user"].data["h"] = x
    with F.record_grad():
        # Per relation: copy the source feature "h" into message "m" and sum
        # messages into "y"; the "sum" cross reducer adds the two relations.
        g.multi_update_all(
            {
                "plays": (fn.copy_u("h", "m"), fn.sum("m", "y")),
                "wishes": (fn.copy_u("h", "m"), fn.sum("m", "y")),
            },
            "sum",
        )
        y = g.nodes["game"].data["y"]
        F.backward(y, F.ones(y.shape))
    print(F.grad(x))
    # Each user row accumulates 1.0 per outgoing "plays"/"wishes" edge.
    assert F.array_equal(
        F.grad(x),
        F.tensor(
            [
                [2.0, 2.0, 2.0, 2.0, 2.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
            ]
        ),
    )
```
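The expected tensor in the assertion can be reproduced without DGL. copy_u followed by sum is linear in x, and the "sum" cross reducer adds the per-relation results, so with an upstream gradient of ones each row x[u] receives 1.0 for every outgoing edge of user u in "plays" or "wishes". The sketch below is a minimal pure-Python illustration under an assumption: the edge lists PLAYS and WISHES are hypothetical stand-ins modeled on the fixture, and grad_wrt_user_features is an illustrative helper, not a DGL function. Any fixture in which each of the three users has two outgoing edges across the two relations gives the same all-2.0 gradient.

```python
# Minimal pure-Python sketch (no DGL) of why every row of F.grad(x) is 2.0.
# ASSUMPTION: PLAYS and WISHES are hypothetical stand-ins for the fixture's
# edges; only the per-user out-degrees affect the result.
from collections import Counter

NUM_USERS, FEAT_DIM = 3, 5

# (source user, destination game) pairs for the two relations in the test.
PLAYS = [(0, 0), (1, 0), (2, 1), (1, 1)]
WISHES = [(0, 1), (2, 0)]


def grad_wrt_user_features(relations, num_users, feat_dim):
    """Gradient of sum(y) w.r.t. x when y is built with copy_u + sum on each
    relation and the relations are combined by a "sum" cross reducer.

    Every edge adds its source row to the destination's sum, so an upstream
    gradient of ones contributes 1.0 to x[u] per outgoing edge of u.
    """
    out_degree = Counter(src for edges in relations for src, _ in edges)
    return [[float(out_degree[u])] * feat_dim for u in range(num_users)]


grad = grad_wrt_user_features([PLAYS, WISHES], NUM_USERS, FEAT_DIM)
print(grad)
```

Because the gradient depends only on out-degrees, the test's random input x = F.randn((3, 5)) does not affect the expected value.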

Callers

nothing calls this directly

Calls (4)

multi_update_all · method · 0.80
grad · method · 0.80
create_test_heterograph · function · 0.70
backward · method · 0.45
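multi_update_all, the first call listed, runs message passing on each given relation and then merges the per-relation results for the shared destination node type with a cross reducer. The sketch below is a hypothetical pure-Python model of the configuration used in the test (copy_u("h", "m") plus sum("m", "y") on each relation, combined with "sum"). It is not DGL's implementation; copy_u_sum, multi_update_all_sum, the edge lists and the toy features are all assumptions for illustration.

```python
# Hypothetical pure-Python model of the forward pass of
#   g.multi_update_all({etype: (fn.copy_u("h", "m"), fn.sum("m", "y"))}, "sum")
# NOT DGL's implementation; edge lists and features are assumptions.


def copy_u_sum(edges, x, num_dst):
    """copy_u then sum for one relation: each destination node sums the
    feature rows of the sources of its incoming edges."""
    dim = len(x[0])
    y = [[0.0] * dim for _ in range(num_dst)]
    for src, dst in edges:
        for k in range(dim):
            y[dst][k] += x[src][k]
    return y


def multi_update_all_sum(relations, x, num_dst):
    """Run each relation separately, then add the per-relation results
    (the "sum" cross-type reducer)."""
    dim = len(x[0])
    total = [[0.0] * dim for _ in range(num_dst)]
    for edges in relations:
        partial = copy_u_sum(edges, x, num_dst)
        for d in range(num_dst):
            for k in range(dim):
                total[d][k] += partial[d][k]
    return total


# Assumed (source user, destination game) edges; 3 users, 2 games.
PLAYS = [(0, 0), (1, 0), (2, 1), (1, 1)]
WISHES = [(0, 1), (2, 0)]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # toy 2-dim user features
y = multi_update_all_sum([PLAYS, WISHES], x, num_dst=2)
print(y)
```

Under these assumed edges each game receives exactly one message from every user, so both rows of y equal the column-wise sum of x.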

Tested by

no test coverage detected