Repository: github.com/huggingface/diffusers

Function assert_tensors_close

tests/testing_utils.py:135–185

Assert that two tensors are close within tolerance. Uses the same element-wise criterion as torch.allclose, |actual - expected| <= atol + rtol * |expected|, and provides concise, actionable error messages without dumping the full tensors.
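The criterion is evaluated element by element, and the relative term scales with |expected| only, so the check is not symmetric in its two arguments. The minimal pure-Python sketch below evaluates the same inequality on scalars; `elementwise_close` is a hypothetical helper written for illustration and is not part of diffusers or PyTorch.

```python
def elementwise_close(actual: float, expected: float, atol: float = 1e-5, rtol: float = 1e-5) -> bool:
    """Hypothetical helper: |actual - expected| <= atol + rtol * |expected| for one pair of values."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Under the default tolerances the allowed error near 1.0 is roughly 2e-5.
print(elementwise_close(1.0, 1.00001))  # True

# Asymmetry: only |expected| scales the relative term, so swapping the arguments can flip the result.
print(elementwise_close(0.0, 1.0, atol=0.0, rtol=1.0))  # True:  1 <= 0 + 1 * 1
print(elementwise_close(1.0, 0.0, atol=0.0, rtol=1.0))  # False: 1 >  0 + 1 * 0
```

Because `expected` sets the scale of the relative tolerance, it should be the reference value (for example, stored known-good outputs) rather than the freshly computed one.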

assert_tensors_close(
    actual: "torch.Tensor",
    expected: "torch.Tensor",
    atol: float = 1e-5,
    rtol: float = 1e-5,
    msg: str = "",
) -> None

Source:

```python
# Excerpt from tests/testing_utils.py: `torch` and `is_torch_available` are expected to be imported at module level.
def assert_tensors_close(
    actual: "torch.Tensor",
    expected: "torch.Tensor",
    atol: float = 1e-5,
    rtol: float = 1e-5,
    msg: str = "",
) -> None:
    """
    Assert that two tensors are close within tolerance.

    Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
    Provides concise, actionable error messages without dumping full tensors.

    Args:
        actual: The actual tensor from the computation.
        expected: The expected tensor to compare against.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        msg: Optional message prefix for the assertion error.

    Raises:
        AssertionError: If tensors have different shapes or values exceed tolerance.

    Example:
        >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
    """
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")

    if actual.shape != expected.shape:
        raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")

    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        abs_diff = (actual - expected).abs()
        max_diff = abs_diff.max().item()

        # Flat index of the largest absolute difference, converted back to a multi-dimensional index.
        flat_idx = abs_diff.argmax().item()
        max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))

        # Per-element threshold, matching the torch.allclose criterion.
        threshold = atol + rtol * expected.abs()
        mismatched = (abs_diff > threshold).sum().item()
        total = actual.numel()

        raise AssertionError(
            f"{msg}\n"
            f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
            f" Max diff: {max_diff:.6e} at index {max_idx}\n"
            f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
            f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
            f" atol: {atol:.6e}, rtol: {rtol:.6e}"
        )
```
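When `torch.allclose` fails, the function reports the largest absolute difference, where it occurs, and how many elements exceed their per-element threshold. The sketch below reproduces that bookkeeping in plain Python on a flat list of values plus a shape, so it runs without PyTorch. `describe_mismatch` and `unravel_index` are hypothetical helpers written for illustration; `unravel_index` mirrors the row-major (C-order) conversion that `torch.unravel_index` performs, and the sketch omits the `msg` prefix.

```python
from typing import List, Sequence, Tuple


def unravel_index(flat_idx: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Convert a flat row-major index into a multi-dimensional index for `shape`."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat_idx % dim)
        flat_idx //= dim
    return tuple(reversed(coords))


def describe_mismatch(
    actual: List[float],
    expected: List[float],
    shape: Sequence[int],
    atol: float = 1e-5,
    rtol: float = 1e-5,
) -> str:
    """Build a summary in the style of assert_tensors_close from flattened inputs."""
    abs_diff = [abs(a - e) for a, e in zip(actual, expected)]
    # Flat argmax: max() returns the first index holding the largest difference.
    flat_idx = max(range(len(abs_diff)), key=abs_diff.__getitem__)
    max_idx = unravel_index(flat_idx, shape)
    # Count elements whose difference exceeds atol + rtol * |expected|.
    mismatched = sum(d > atol + rtol * abs(e) for d, e in zip(abs_diff, expected))
    total = len(actual)
    return (
        f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
        f" Max diff: {abs_diff[flat_idx]:.6e} at index {max_idx}\n"
        f" Actual: {actual[flat_idx]:.6e}\n"
        f" Expected: {expected[flat_idx]:.6e}\n"
        f" atol: {atol:.6e}, rtol: {rtol:.6e}"
    )


# A 2x3 "tensor" with one bad element at flat index 4, i.e. index (1, 1).
print(describe_mismatch([0.0, 0.1, 0.2, 0.3, 0.4, 0.5], [0.0, 0.1, 0.2, 0.3, 0.45, 0.5], (2, 3)))
```

Reporting one location and a mismatch count keeps failure output short even for large tensors. Note that the reported location is the largest absolute difference, which need not be the element that exceeds its own threshold by the widest margin when `rtol` is nonzero, since each element's threshold depends on its |expected|.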

Calls

is_torch_available (function)

Tested by

no test coverage detected
