(
comparison,
x,
y,
err_msg="",
verbose=True,
header="",
precision=6,
equal_nan=True,
equal_inf=True,
*,
strict=False,
)
| 549 | |
| 550 | |
| 551 | def assert_array_compare( |
| 552 | comparison, |
| 553 | x, |
| 554 | y, |
| 555 | err_msg="", |
| 556 | verbose=True, |
| 557 | header="", |
| 558 | precision=6, |
| 559 | equal_nan=True, |
| 560 | equal_inf=True, |
| 561 | *, |
| 562 | strict=False, |
| 563 | ): |
| 564 | __tracebackhide__ = True # Hide traceback for py.test |
| 565 | from torch._numpy import all, array, asarray, bool_, inf, isnan, max |
| 566 | |
| 567 | x = asarray(x) |
| 568 | y = asarray(y) |
| 569 | |
| 570 | def array2string(a): |
| 571 | return str(a) |
| 572 | |
| 573 | # original array for output formatting |
| 574 | ox, oy = x, y |
| 575 | |
| 576 | def func_assert_same_pos(x, y, func=isnan, hasval="nan"): |
| 577 | """Handling nan/inf. |
| 578 | |
| 579 | Combine results of running func on x and y, checking that they are True |
| 580 | at the same locations. |
| 581 | |
| 582 | """ |
| 583 | __tracebackhide__ = True # Hide traceback for py.test |
| 584 | x_id = func(x) |
| 585 | y_id = func(y) |
| 586 | # We include work-arounds here to handle three types of slightly |
| 587 | # pathological ndarray subclasses: |
| 588 | # (1) all() on `masked` array scalars can return masked arrays, so we |
| 589 | # use != True |
| 590 | # (2) __eq__ on some ndarray subclasses returns Python booleans |
| 591 | # instead of element-wise comparisons, so we cast to bool_() and |
| 592 | # use isinstance(..., bool) checks |
| 593 | # (3) subclasses with bare-bones __array_function__ implementations may |
| 594 | # not implement np.all(), so favor using the .all() method |
| 595 | # We are not committed to supporting such subclasses, but it's nice to |
| 596 | # support them if possible. |
| 597 | if (x_id == y_id).all().item() is not True: |
| 598 | msg = build_err_msg( |
| 599 | [x, y], |
| 600 | err_msg + "\nx and y %s location mismatch:" % (hasval), |
| 601 | verbose=verbose, |
| 602 | header=header, |
| 603 | names=("x", "y"), |
| 604 | precision=precision, |
| 605 | ) |
| 606 | raise AssertionError(msg) |
| 607 | # If there is a scalar, then here we know the array has the same |
| 608 | # flag as it everywhere, so we should return the scalar flag. |
no test coverage detected
searching dependent graphs…