| 71 | |
| 72 | |
def assertTensorTuplesAlmostEqual(
    test: unittest.TestCase,
    # pyre-fixme[2]: Parameter must be annotated.
    actual,
    # pyre-fixme[2]: Parameter must be annotated.
    expected,
    delta: float = 0.0001,
    mode: str = "sum",
) -> None:
    """Compare ``actual`` against ``expected`` via ``assertTensorAlmostEqual``.

    When ``expected`` is a tuple, ``actual`` must have the same length and
    the two are compared position by position (``actual`` is indexed with
    the position of each ``expected`` element). Any other ``expected`` is
    handed to ``assertTensorAlmostEqual`` together with ``actual`` as a
    single pair.

    Args:
        test: Test case forwarded unchanged to ``assertTensorAlmostEqual``.
        actual: Value under test; must support ``len()`` and integer
            indexing when ``expected`` is a tuple.
        expected: Reference value, either a tuple of items or a single item.
        delta: Forwarded unchanged to ``assertTensorAlmostEqual``.
        mode: Forwarded unchanged to ``assertTensorAlmostEqual``.

    Raises:
        AssertionError: If ``expected`` is a tuple whose length differs from
            ``len(actual)``. The check is a plain ``assert`` statement, so it
            is not executed when Python runs with ``-O``.
    """
    # Single (non-tuple) expectation: one direct comparison, nothing to unpack.
    if not isinstance(expected, tuple):
        assertTensorAlmostEqual(test, actual, expected, delta, mode)
        return

    # Fail before any element comparison if the containers differ in size.
    assert len(actual) == len(
        expected
    ), f"the length of actual {len(actual)} != expected {len(expected)}"

    for position, expected_item in enumerate(expected):
        assertTensorAlmostEqual(
            test, actual[position], expected_item, delta, mode
        )
| 91 | |
| 92 | |
| 93 | def assertTupleOfListOfTensorsAlmostEqual( |