Exact math match holds if and only if one of the following is true:
1. Numerically equal: both can be converted to float and are equal.
2. Symbolically equal: both can be converted to sympy expressions and are equal.
(prediction: Union[bool, float, str],
reference: Union[float, str],
include_percentage: bool = True,
is_close: bool = True,
timeout: bool = False,
)
| 20 | return False |
| 21 | |
| 22 | def math_equal(prediction: Union[bool, float, str], |
| 23 | reference: Union[float, str], |
| 24 | include_percentage: bool = True, |
| 25 | is_close: bool = True, |
| 26 | timeout: bool = False, |
| 27 | ) -> bool: |
| 28 | """ |
| 29 | Exact match of math if and only if: |
| 30 | 1. numerical equal: both can convert to float and are equal |
| 31 | 2. symbolic equal: both can convert to sympy expression and are equal |
| 32 | """ |
| 33 | try: # 1. numerical equal |
| 34 | if is_digit(prediction) and is_digit(reference): |
| 35 | prediction = float(str(prediction).replace(",", "")) |
| 36 | reference = float(str(reference).replace(",", "")) |
| 37 | # number questions |
| 38 | if include_percentage: |
| 39 | gt_result = [reference / 100, reference, reference * 100] |
| 40 | else: |
| 41 | gt_result = [reference] |
| 42 | for item in gt_result: |
| 43 | try: |
| 44 | if is_close: |
| 45 | if isclose(item, prediction, rel_tol=1e-4): |
| 46 | return True |
| 47 | else: |
| 48 | if item == prediction: |
| 49 | return True |
| 50 | except Exception: |
| 51 | continue |
| 52 | return False |
| 53 | except: |
| 54 | pass |
| 55 | |
| 56 | if not prediction and prediction not in [0, False]: |
| 57 | return False |
| 58 | |
| 59 | # 2. symbolic equal |
| 60 | reference = str(reference).strip() |
| 61 | prediction = str(prediction).strip() |
| 62 | |
| 63 | ## deal with [], (), {} |
| 64 | pred_str, ref_str = prediction, reference |
| 65 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \ |
| 66 | (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): |
| 67 | pred_str = pred_str.strip("[]()") |
| 68 | ref_str = ref_str.strip("[]()") |
| 69 | for s in ['{', "}", "(", ")"]: |
| 70 | ref_str = ref_str.replace(s, "") |
| 71 | pred_str = pred_str.replace(s, "") |
| 72 | if pred_str == ref_str: |
| 73 | return True |
| 74 | |
| 75 | ## [a, b] vs. [c, d], return a==c and b==d |
| 76 | if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \ |
| 77 | (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")): |
| 78 | pred_parts = prediction[1:-1].split(",") |
| 79 | ref_parts = reference[1:-1].split(",") |
no test coverage detected