The answer is considered correct if either (a) it normalizes to the same string as the ground-truth answer, or (b) sympy can simplify the difference between the two expressions to 0.
(given_answer: str, ground_truth: str)
| 232 | |
| 233 | |
def grade_answer(given_answer: str, ground_truth: str) -> bool:
    """Decide whether ``given_answer`` should be accepted for ``ground_truth``.

    The answer will be considered correct if:
    (a) it normalizes to the same string as the ground truth answer (under
        either ``math_normalize.normalize_answer`` or ``_normalize``)
    OR
    (b) both answers split (via ``split_tuple``) into the same number of
        elements and every corresponding pair of elements is accepted.

    For route (b) a pair of elements is accepted as follows:

    * if ``_is_frac`` holds for both, they must be textually identical
      (unreduced fractions must not be marked correct, so sympy is skipped);
    * if ``_str_is_int`` holds for exactly one of them, the pair is rejected
      without consulting sympy;
    * otherwise the verdict is ``are_equal_under_sympy(ground, given)``.

    Additionally, when the ground truth splits into more than one element, the
    first and last characters of the two normalized strings must agree
    (presumably the enclosing delimiters of a tuple/interval -- confirm
    against ``split_tuple``).

    Args:
        given_answer: The answer to grade. ``None`` is always graded False.
        ground_truth: The reference answer.

    Returns:
        True if the given answer is accepted, False otherwise.
    """
    if given_answer is None:
        return False

    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)

    # be at least as lenient as mathd
    if ground_truth_normalized_mathd == given_answer_normalized_mathd:
        return True

    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)

    if ground_truth_normalized is None:
        return False

    if ground_truth_normalized == given_normalized:
        return True

    # Truthiness also covers a None result from _normalize, which would
    # otherwise crash on len() instead of being graded as wrong.
    if not given_normalized:
        return False

    ground_truth_elems = split_tuple(ground_truth_normalized)
    given_elems = split_tuple(given_normalized)

    # For multi-element ground truths the outermost characters must agree.
    if len(ground_truth_elems) > 1 and (
        ground_truth_normalized[0] != given_normalized[0]
        or ground_truth_normalized[-1] != given_normalized[-1]
    ):
        return False

    if len(ground_truth_elems) != len(given_elems):
        return False

    # Nothing to compare element-wise: the whole-string checks above already
    # failed, so be conservative instead of falling through with no verdict.
    if not ground_truth_elems:
        return False

    for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
        if _is_frac(ground_truth_elem) and _is_frac(given_elem):
            # if fractions aren't reduced, then shouldn't be marked as correct
            # so, we don't want to allow sympy.simplify in this case
            if ground_truth_elem != given_elem:
                return False
        elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
            # exactly one side is an integer literal: reject outright rather
            # than letting sympy.simplify equate e.g. an expression with an int
            return False
        elif not are_equal_under_sympy(ground_truth_elem, given_elem):
            return False

    return True
no test coverage detected