Evaluate whether the response matches the ground truth, based on the problem category. Args: response: the model's response; ground_truth: the correct answer; category: the problem category (gsm8k, mmlu_math, boolq, aqua_rat); question: the original question text, needed for MMLU evaluation.
(response: str, ground_truth: str, category: str, question: str = None)
| 204 | return False, -1 |
| 205 | |
| 206 | def evaluate_response(response: str, ground_truth: str, category: str, question: str = None) -> bool: |
| 207 | """ |
| 208 | Evaluate if the response matches the ground truth based on category. |
| 209 | |
| 210 | Args: |
| 211 | response: Model's response |
| 212 | ground_truth: Correct answer |
| 213 | category: Problem category (gsm8k, mmlu_math, boolq, aqua_rat) |
| 214 | question: Original question text, needed for MMLU evaluation |
| 215 | |
| 216 | Returns: |
| 217 | bool: Whether the response is correct |
| 218 | """ |
| 219 | if not response or not ground_truth: |
| 220 | return False |
| 221 | |
| 222 | # First, remove any thinking blocks |
| 223 | response = remove_thinking_blocks(response) |
| 224 | |
| 225 | if category == "gsm8k": |
| 226 | # Extract numerical answers after ### and compare |
| 227 | response_num = extract_gsm8k_answer(response) |
| 228 | ground_truth_num = extract_gsm8k_answer(ground_truth) |
| 229 | |
| 230 | if response_num is None or ground_truth_num is None: |
| 231 | return False |
| 232 | |
| 233 | # Compare with small tolerance for floating point |
| 234 | return abs(response_num - ground_truth_num) < 1e-6 |
| 235 | elif category == "mmlu_math": |
| 236 | # Special handling for MMLU-math multiple choice questions |
| 237 | response_clean = response.strip().lower() |
| 238 | ground_truth_clean = ground_truth.strip().lower() |
| 239 | |
| 240 | # Case 1: Exact match of answer text |
| 241 | if response_clean == ground_truth_clean: |
| 242 | logger.debug("Exact text match") |
| 243 | return True |
| 244 | |
| 245 | # For other cases, we need to find what index corresponds to the ground truth |
| 246 | if question: |
| 247 | correct_index = extract_choice_index_from_question(question, ground_truth) |
| 248 | |
| 249 | if correct_index >= 0: |
| 250 | # Case 2: Check if response is just the digit (most common LLM response for indices) |
| 251 | is_numeric, value = is_numeric_only_response(response) |
| 252 | if is_numeric and value == correct_index: |
| 253 | logger.debug(f"Numeric match: response '{response}' -> {value} matches index {correct_index}") |
| 254 | return True |
| 255 | |
| 256 | # Case 3: Check if response is "index. answer" |
| 257 | if re.search(fr"{correct_index}\s*\.\s*{re.escape(ground_truth_clean)}", response_clean): |
| 258 | logger.debug("Pattern match for 'index. answer'") |
| 259 | return True |
| 260 | |
| 261 | # Case 4: Check if response contains both the index and the answer text |
| 262 | if str(correct_index) in response_clean and ground_truth_clean in response_clean: |
| 263 | logger.debug("Contains both index and answer") |
no test coverage detected