Assert that two tensors are close within tolerance. Uses the same formula as `torch.allclose`: |actual - expected| <= atol + rtol * |expected|. Provides concise, actionable error messages without dumping full tensors. Args — actual: the actual tensor from the computation.
(
actual: "torch.Tensor",
expected: "torch.Tensor",
atol: float = 1e-5,
rtol: float = 1e-5,
msg: str = "",
)
| 133 | |
| 134 | |
def assert_tensors_close(
    actual: "torch.Tensor",
    expected: "torch.Tensor",
    atol: float = 1e-5,
    rtol: float = 1e-5,
    msg: str = "",
) -> None:
    """
    Assert that two tensors are close within tolerance.

    Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|.
    Provides concise, actionable error messages without dumping full tensors.

    The mismatch count and the reported element are derived from the same
    element-wise test (``torch.isclose``) that decides pass/fail. NaN and
    infinite values are therefore counted exactly as ``torch.allclose``
    classifies them.

    Args:
        actual: The actual tensor from the computation.
        expected: The expected tensor to compare against.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        msg: Optional message prefix for the assertion error.

    Raises:
        ValueError: If PyTorch is not installed.
        AssertionError: If tensors have different shapes or dtypes, or values exceed tolerance.

    Example:
        >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
    """
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")

    if actual.shape != expected.shape:
        raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")

    # torch.isclose/allclose reject mixed dtypes with an opaque RuntimeError;
    # report it as an assertion failure with a readable message instead.
    if actual.dtype != expected.dtype:
        raise AssertionError(f"{msg} Dtype mismatch: actual {actual.dtype} vs expected {expected.dtype}")

    # Element-wise form of torch.allclose (allclose == isclose(...).all()).
    # Every statistic below comes from this one mask, so the reported numbers
    # always agree with the pass/fail decision, including for NaN/inf elements.
    close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
    if bool(close.all()):
        return

    if actual.is_floating_point() or actual.is_complex():
        actual_cmp, expected_cmp = actual, expected
    else:
        # Integer/bool inputs: subtract in floating point so unsigned values
        # cannot wrap around and bool subtraction (unsupported) is avoided.
        actual_cmp = actual.to(torch.get_default_dtype())
        expected_cmp = expected.to(torch.get_default_dtype())
    abs_diff = (actual_cmp - expected_cmp).abs()

    # Report the worst *failing* element. Masking out passing elements stops a
    # large-but-tolerated difference (allowed via rtol) from hiding the failure.
    failing_diff = abs_diff.masked_fill(close, float("-inf"))
    flat_idx = failing_diff.argmax().item()
    max_diff = failing_diff.reshape(-1)[flat_idx].item()
    max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))

    mismatched = (~close).sum().item()
    total = actual.numel()

    raise AssertionError(
        f"{msg}\n"
        f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
        f" Max diff: {max_diff:.6e} at index {max_idx}\n"
        f" Actual: {actual.reshape(-1)[flat_idx].item():.6e}\n"
        f" Expected: {expected.reshape(-1)[flat_idx].item():.6e}\n"
        f" atol: {atol:.6e}, rtol: {rtol:.6e}"
    )
| 186 | |
| 187 | |
| 188 | def numpy_cosine_similarity_distance(a, b): |
no test coverage detected
searching dependent graphs…