MCPcopy
hub / github.com/meta-pytorch/captum / assertTensorAlmostEqual

Function assertTensorAlmostEqual

captum/testing/helpers/basic.py:27–70  ·  view source on GitHub ↗
(
    test: unittest.TestCase,
    # pyre-fixme[2]: Parameter must be annotated.
    actual,
    # pyre-fixme[2]: Parameter must be annotated.
    expected,
    delta: float = 0.0001,
    mode: str = "sum",
)

Source from the content-addressed store, hash-verified

25
26
def assertTensorAlmostEqual(
    test: unittest.TestCase,
    # pyre-fixme[2]: Parameter must be annotated.
    actual,
    # pyre-fixme[2]: Parameter must be annotated.
    expected,
    delta: float = 0.0001,
    mode: str = "sum",
) -> None:
    """Assert that tensor ``actual`` matches ``expected`` within ``delta``.

    Args:
        test: Test case used to report failures. Mismatches are raised through
            ``test.fail`` / ``test.assertAlmostEqual``, i.e. as
            ``test.failureException`` (``AssertionError`` by default).
        actual: Value under test; must be a ``torch.Tensor``.
        expected: Reference value. A non-tensor value is converted with
            ``torch.tensor(expected, dtype=actual.dtype)``.
        delta: Tolerance for the comparison.
        mode: ``"sum"`` requires the sum of absolute element-wise differences
            to be within ``delta`` of 0. ``"max"`` requires every absolute
            element-wise difference to be at most ``delta``; for tensors with
            at least one dimension, the first failing index along dim 0 is
            reported.

    Raises:
        AssertionError: (``test.failureException``) if ``actual`` is not a
            tensor, the shapes differ, or the values are not close enough.
        ValueError: If ``mode`` is neither ``"sum"`` nor ``"max"``.
    """
    # Failures go through ``test.fail`` rather than bare ``assert`` so the
    # checks are not stripped when Python runs with ``-O``.
    if not isinstance(actual, torch.Tensor):
        test.fail("Actual parameter given for comparison must be a tensor.")
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)
    # Checked explicitly: the arithmetic below would otherwise broadcast
    # mismatched shapes instead of failing.
    if actual.shape != expected.shape:
        test.fail(
            f"Expected tensor with shape: {expected.shape}. "
            f"Actual shape {actual.shape}."
        )
    actual = actual.cpu()
    expected = expected.cpu()
    if mode == "sum":
        test.assertAlmostEqual(
            torch.sum(torch.abs(actual - expected)).item(), 0.0, delta=delta
        )
    elif mode == "max":
        # If both tensors are empty they are equal, but there is no max.
        if actual.numel() == expected.numel() == 0:
            return

        if actual.size() == torch.Size([]):
            test.assertAlmostEqual(
                torch.max(torch.abs(actual - expected)).item(), 0.0, delta=delta
            )
        else:
            # Compare slice by slice along dim 0 so a failure names the first
            # offending index and shows both slices.
            for index, (actual_row, expected_row) in enumerate(zip(actual, expected)):
                if not torch.all(torch.abs(actual_row - expected_row) <= delta):
                    test.fail(
                        "Values at index {}, {} and {}, differ more than by {}".format(
                            index, actual_row, expected_row, delta
                        )
                    )
    else:
        raise ValueError(
            "Mode for assertion comparison must be one of `max` or `sum`."
        )
71
72
73def assertTensorTuplesAlmostEqual(

Calls

no outgoing calls