(idtype)
| 2046 | |
@parametrize_idtype
def test_backward(idtype):
    """Check that gradients flow back through a multi-etype ``multi_update_all``.

    The ``user`` node feature ``h`` is sent along both the ``plays`` and
    ``wishes`` edge types with ``copy_u`` and reduced into ``y`` on the
    ``game`` nodes with ``sum``. The per-etype results are then combined
    with ``"sum"``. Backpropagating an all-ones upstream gradient from ``y``
    must produce the expected gradient on ``x``.
    """
    g = create_test_heterograph(idtype)
    x = F.randn((3, 5))
    F.attach_grad(x)
    g.nodes["user"].data["h"] = x
    with F.record_grad():
        g.multi_update_all(
            {
                "plays": (fn.copy_u("h", "m"), fn.sum("m", "y")),
                "wishes": (fn.copy_u("h", "m"), fn.sum("m", "y")),
            },
            "sum",
        )
        y = g.nodes["game"].data["y"]
        F.backward(y, F.ones(y.shape))
    # Assuming copy_u forwards the source feature unchanged and sum adds the
    # messages, each user row's gradient is the count of its outgoing
    # plays/wishes edges into game nodes. The uniform 2.0 relies on the edge
    # layout built by create_test_heterograph.
    # NOTE(review): confirm against that helper.
    assert F.array_equal(
        F.grad(x),
        F.tensor(
            [
                [2.0, 2.0, 2.0, 2.0, 2.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
            ]
        ),
    )
| 2074 | |
| 2075 | |
| 2076 | @parametrize_idtype |
nothing calls this directly
no test coverage detected