(sample_dir: str, dataset: str = "humaneval", debug_task: str = None)
| 148 | |
| 149 | |
def script(sample_dir: str, dataset: str = "humaneval", debug_task: str = None):
    """Test the gathered solutions of each task and append the kept ones to a file.

    For every task of the selected dataset the solutions found under
    ``sample_dir`` are gathered, de-duplicated and handed to
    ``test_solutions``; whatever that call returns is stripped of ``print``
    calls and appended as one JSON line to ``solutions.jsonl`` (a relative
    path, so it is resolved against the current working directory).

    Args:
        sample_dir: Location forwarded to ``gather_solutions``.
        dataset: Either ``"humaneval"`` or ``"mbpp"``.
        debug_task: When truthy, only the task with this exact id is
            processed; every other task is skipped.

    Raises:
        AssertionError: If ``dataset`` is not one of the two supported names
            (only while assertions are enabled).
    """
    assert dataset in ("humaneval", "mbpp")

    # Load the problem set, its hash and the ground-truth outputs.
    if dataset == "mbpp":
        # Imported lazily: only the MBPP path needs this task list.
        from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS

        problems = get_mbpp_plus(noextreme=True)
        dataset_hash = get_mbpp_plus_hash(noextreme=True)
        expected_output = get_groundtruth(
            problems, dataset_hash, MBPP_OUTPUT_NOT_NONE_TASKS
        )
    elif dataset == "humaneval":
        problems = get_human_eval_plus(noextreme=True)
        dataset_hash = get_human_eval_plus_hash(noextreme=True)
        expected_output = get_groundtruth(problems, dataset_hash, [])

    # NOTE(review): nothing visible here ever assigns a real value, so the
    # merge step further down is currently inactive.
    previous_solutions = None

    for task_id, task in tqdm(problems.items()):
        selected = not debug_task or task_id == debug_task
        if not selected:
            continue

        # "/" in the task id is swapped for "_" before the lookup
        # (presumably to get a filesystem-safe name -- confirm against
        # gather_solutions).
        safe_name = task_id.replace("/", "_")
        candidates = deduplicate(gather_solutions(sample_dir, safe_name))
        passing = test_solutions(dataset, candidates, task, expected_output[task_id])

        # Clean each kept solution: void its print calls and keep the first
        # element of what void_calls returns.
        cleaned = []
        for candidate in passing:
            cleaned.append(void_calls(candidate, ["print"])[0])

        # Previously stored solutions are assumed correct and merged as-is.
        if previous_solutions and task_id in previous_solutions:
            cleaned = deduplicate(cleaned + previous_solutions[task_id])

        # Re-open in append mode for every task so each record is written
        # out as soon as the task is finished.
        record = {"task_id": task_id, "solution": cleaned}
        with open("solutions.jsonl", "a+") as out_file:
            out_file.write(json.dumps(record) + "\n")
| 192 | |
| 193 | |
| 194 | def main(): |
nothing calls this directly
no test coverage detected