Checks if the answer uses all numbers exactly once and evaluates to the target
(
response: str, numbers: List[int] = None, target: int = None
)
| 111 | |
| 112 | |
def answer_reward_function(
    response: str, numbers: List[int] = None, target: int = None
) -> float:
    """
    Checks if the answer uses all numbers exactly once and evaluates to the target

    Args:
        response: Text that should contain an ``<answer>...</answer>`` span.
            Only the first such span is scored.
        numbers: Integers that must each appear exactly once in the answer
            expression (order does not matter).
        target: Value the answer expression must evaluate to, within an
            absolute tolerance of 1e-5.

    Returns:
        1.0 if the answer is an arithmetic expression made only of digits,
        ``+ - * /``, parentheses and spaces, uses exactly ``numbers`` and
        evaluates to ``target``; 0.0 in every other case, including when
        ``numbers`` or ``target`` is not supplied.
    """
    # Without both reference values nothing can be verified. Previously a
    # missing `numbers` crashed in sorted(None) while a missing `target`
    # silently scored 0.0; both now score 0.0.
    if numbers is None or target is None:
        return 0.0

    answer_regex = r"<answer>(.*?)<\/answer>"
    answer_match = re.search(answer_regex, response, re.DOTALL)
    if not answer_match:
        return 0.0

    answer_content = answer_match.group(1)
    if not answer_content:
        return 0.0

    # Whitelist of characters: this is what keeps names, attribute access and
    # calls to anything but literals out of the eval() below.
    allowed_chars = r"^[0-9+\-*/() ]+$"
    if not re.match(allowed_chars, answer_content):
        return 0.0

    # Check if the answer uses all numbers exactly once
    try:
        used_numbers = [int(n) for n in re.findall(r"\d+", answer_content)]
    except ValueError:
        # int() rejects extremely long digit strings on interpreters that
        # enforce the integer string conversion length limit.
        return 0.0
    if sorted(used_numbers) != sorted(numbers):
        return 0.0

    # Check if the answer evaluates to the target
    # NOTE(security): eval() runs on model-generated text. The whitelist above
    # limits it to arithmetic, but `**` and `//` still pass the filter, so a
    # hostile expression can be expensive to evaluate. Consider an AST-based
    # evaluator if responses are not trusted.
    try:
        result = eval(answer_content, {"__builtins__": None}, {})
        if abs(float(result) - float(target)) < 1e-5:
            return 1.0
    except Exception:
        # Malformed or non-numeric expressions (SyntaxError, ZeroDivisionError,
        # TypeError, OverflowError, ...) simply earn no reward. Unlike a bare
        # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
        pass

    return 0.0
| 146 | |
| 147 | |
| 148 | def reward_function( |