(
test: unittest.TestCase,
# pyre-fixme[2]: Parameter must be annotated.
actual,
# pyre-fixme[2]: Parameter must be annotated.
expected,
delta: float = 0.0001,
mode: str = "sum",
)
| 25 | |
| 26 | |
def assertTensorAlmostEqual(
    test: unittest.TestCase,
    # pyre-fixme[2]: Parameter must be annotated.
    actual,
    # pyre-fixme[2]: Parameter must be annotated.
    expected,
    delta: float = 0.0001,
    mode: str = "sum",
) -> None:
    """Fail ``test`` unless ``actual`` is close to ``expected`` within ``delta``.

    Args:
        test: The running test case. Failures are reported through it
            (``test.fail`` / ``test.assertAlmostEqual``), so they raise
            ``test.failureException`` (``AssertionError`` by default).
        actual: The tensor under test; must be a ``torch.Tensor``.
        expected: Reference values. A non-tensor is converted with
            ``torch.tensor(expected, dtype=actual.dtype)``. After conversion
            it must have exactly the same shape as ``actual``.
        delta: Non-negative tolerance.
        mode: ``"sum"``: the sum of absolute element differences must be
            within ``delta``. ``"max"``: every element's absolute difference
            must be within ``delta``. For non-scalar tensors the check is
            done per slice along the first dimension, so the failure message
            names the first offending index.

    Raises:
        AssertionError: (``test.failureException``) on a type, shape, or
            value mismatch.
        ValueError: if ``mode`` is neither ``"sum"`` nor ``"max"``.
    """
    # Report through the TestCase instead of bare ``assert``: bare asserts
    # are removed under ``python -O``, which would silently disable these
    # checks (including the whole element-wise comparison in "max" mode).
    if not isinstance(actual, torch.Tensor):
        test.fail("Actual parameter given for comparison must be a tensor.")
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)
    # An explicit shape check matters: without it, ``actual - expected``
    # below could broadcast mismatched shapes and hide a real difference.
    if actual.shape != expected.shape:
        test.fail(
            f"Expected tensor with shape: {expected.shape}. "
            f"Actual shape {actual.shape}."
        )
    actual = actual.cpu()
    expected = expected.cpu()
    if mode == "sum":
        test.assertAlmostEqual(
            torch.sum(torch.abs(actual - expected)).item(), 0.0, delta=delta
        )
    elif mode == "max":
        # Shapes are equal, so both tensors are empty together. They are then
        # trivially equal, and torch.max would raise on an empty tensor.
        if actual.numel() == 0:
            return

        if actual.dim() == 0:
            test.assertAlmostEqual(
                torch.max(torch.abs(actual - expected)).item(), 0.0, delta=delta
            )
        else:
            # Iterating a tensor yields slices along dim 0. Each slice is a
            # tensor, so reduce the element-wise comparison with torch.all.
            for index, (row, ref) in enumerate(zip(actual, expected)):
                if not torch.all(torch.abs(row - ref) <= delta):
                    test.fail(
                        "Values at index {}, {} and {}, differ more than by {}".format(
                            index, row, ref, delta
                        )
                    )
    else:
        raise ValueError("Mode for assertion comparison must be one of `max` or `sum`.")
| 71 | |
| 72 | |
| 73 | def assertTensorTuplesAlmostEqual( |
no outgoing calls