Checks if the response follows the format <think>...</think><answer>...</answer>
(response: str, end_token: Optional[str] = None)
| 81 | |
| 82 | |
def format_reward_function(response: str, end_token: Optional[str] = None) -> float:
    """
    Score how closely ``response`` follows the think-then-answer layout.

    The strict layout is a ``<think>...</think>`` block, exactly one newline,
    then an ``<answer>...</answer>`` block, with nothing else around them.
    Tag contents may span multiple lines (``re.DOTALL``).

    Because the strict check uses ``re.match`` with a ``$`` anchor (no
    ``re.MULTILINE``), a single trailing newline after ``</answer>`` is still
    accepted.

    Args:
        response: The model output to score.
        end_token: Optional terminator string. When it is non-empty and
            ``response`` ends with it, it is removed before any check.

    Returns:
        1.0 if the strict layout matches. Otherwise partial credit:
        +0.1 if a ``<think>...</think>`` pair appears anywhere, plus
        +0.5 if an ``<answer>...</answer>`` pair appears anywhere.
        The result is therefore one of 0.0, 0.1, 0.5 or 0.6.
    """
    trimmed = response
    if end_token and trimmed.endswith(end_token):
        # Remove the terminator so it does not defeat the `$` anchor below.
        trimmed = trimmed[: len(trimmed) - len(end_token)]

    strict_layout = r"^<think>.*?<\/think>\n<answer>.*?<\/answer>$"
    if re.match(strict_layout, trimmed, re.DOTALL) is not None:
        return 1.0

    # Partial credit: each tag pair is looked for independently, in any position.
    has_think = re.search(r"<think>.*?<\/think>", trimmed, re.DOTALL) is not None
    has_answer = re.search(r"<answer>.*?<\/answer>", trimmed, re.DOTALL) is not None

    partial = 0.0
    partial += 0.1 if has_think else 0.0
    partial += 0.5 if has_answer else 0.0
    return partial
| 111 | |
| 112 | |
| 113 | def answer_reward_function( |