(self)
| 6 | |
| 7 | class TestTensorGradient(unittest.TestCase): |
| 8 | def test_example(self): |
| 9 | x = Tensor.eye(3) |
| 10 | y = Tensor([[2.0,0,-2.0]]) |
| 11 | z = y.matmul(x).sum() |
| 12 | dx, dy = z.gradient(x, y) |
| 13 | self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]]) |
| 14 | self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]]) |
| 15 | |
| 16 | def test_zero_if_not_used(self): |
| 17 | x = Tensor([1.0, 2.0, 3.0]) |