Repository: github.com/meta-pytorch/captum

Method test_gradient_basic_2

tests/utils/test_gradient.py:61–68
Signature: (self) -> None


def test_gradient_basic_2(self) -> None:
    model = BasicModel()
    input = torch.tensor([[-3.0]], requires_grad=True)
    # Pre-populate .grad so the test can check it is left unchanged
    input.grad = torch.tensor([[14.0]])
    # compute_gradients returns one gradient per input; [0] selects the one for `input`
    grads = compute_gradients(model, input)[0]
    assertTensorAlmostEqual(self, grads, [[1.0]], delta=0.01, mode="max")
    # Verify grad attribute is not altered
    assertTensorAlmostEqual(self, input.grad, [[14.0]], delta=0.0, mode="max")
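The test's second assertion checks that computing gradients does not touch a pre-existing `.grad`. The sketch below shows one way that behavior can arise in plain PyTorch, and why it matters. It rests on several assumptions. `BasicModel`'s definition is not shown on this page, so `StandInModel` is a hypothetical stand-in, f(x) = 1 - relu(1 - x), chosen only because its derivative at x = -3.0 is 1.0, the value the test expects. `gradients_without_side_effects` is also hypothetical. Whether captum's `compute_gradients` is built this way is not confirmed here. `torch.autograd.grad` returns gradients instead of accumulating them, while `backward()` adds into `.grad`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandInModel(nn.Module):
    """Hypothetical stand-in for BasicModel (its real definition is not shown)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # f(x) = 1 - relu(1 - x); df/dx = 1 when x < 1, 0 when x > 1
        return 1 - F.relu(1 - x)


def gradients_without_side_effects(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gradient of the summed model output w.r.t. inputs.

    torch.autograd.grad returns the gradient instead of accumulating it
    into inputs.grad, so a pre-existing .grad value is left untouched.
    """
    outputs = model(inputs)
    (grads,) = torch.autograd.grad(outputs.sum(), inputs)
    return grads


model = StandInModel()

# Side-effect-free path: .grad keeps its pre-set value.
x = torch.tensor([[-3.0]], requires_grad=True)
x.grad = torch.tensor([[14.0]])
g = gradients_without_side_effects(model, x)
print(g)       # tensor([[1.]])
print(x.grad)  # tensor([[14.]])

# Contrast: backward() accumulates into .grad, giving 14 + 1 = 15.
y = torch.tensor([[-3.0]], requires_grad=True)
y.grad = torch.tensor([[14.0]])
model(y).sum().backward()
print(y.grad)  # tensor([[15.]])
```

A helper built on `backward()` would leave `[[15.0]]` in `.grad` and fail the test's exact (delta=0.0) check against `[[14.0]]`.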

Callers

Nothing calls this method directly.

Calls (3)

BasicModel (class) · 0.90
compute_gradients (function) · 0.90
assertTensorAlmostEqual (function) · 0.90
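The test calls `assertTensorAlmostEqual` with `mode="max"` and two different tolerances. It uses delta=0.01 for the computed gradient and delta=0.0 (an exact match) for the untouched `.grad`. The sketch below restates what a max-mode check is understood to do: pass when the largest element-wise absolute difference is at most `delta`. The function name `max_abs_diff_within` is hypothetical. This is an illustration under that assumption, not captum's implementation.

```python
import torch


def max_abs_diff_within(actual: torch.Tensor, expected, delta: float) -> bool:
    """Hypothetical restatement of a mode="max" tolerance check:
    passes when max(|actual - expected|) <= delta."""
    expected_t = torch.as_tensor(expected, dtype=actual.dtype)
    return bool((actual - expected_t).abs().max() <= delta)


# A gradient slightly off from 1.0 passes with delta=0.01 but not with delta=0.0.
near = max_abs_diff_within(torch.tensor([[0.995]]), [[1.0]], delta=0.01)
exact_fail = max_abs_diff_within(torch.tensor([[0.995]]), [[1.0]], delta=0.0)
# An untouched .grad matches its pre-set value exactly.
exact_ok = max_abs_diff_within(torch.tensor([[14.0]]), [[14.0]], delta=0.0)
print(near, exact_fail, exact_ok)  # True False True
```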

Tested by

No test coverage detected.