()
| 161 | |
| 162 | |
| 163 | def main(): |
| 164 | import argparse |
| 165 | |
| 166 | parser = argparse.ArgumentParser() |
| 167 | parser.add_argument("--debug-tasks", nargs="+", default=[], type=int) |
| 168 | |
| 169 | args = parser.parse_args() |
| 170 | console = Console() |
| 171 | |
| 172 | if hasattr(sys, "set_int_max_str_digits"): |
| 173 | sys.set_int_max_str_digits(int(10e8)) |
| 174 | |
| 175 | plus_problems = get_mbpp_plus(mini=False) |
| 176 | dataset_hash = get_mbpp_plus_hash() |
| 177 | |
| 178 | original_mbpp = get_mbpp() |
| 179 | |
| 180 | compatible_problems = {} |
| 181 | expected_outputs = get_groundtruth( |
| 182 | plus_problems, dataset_hash, MBPP_OUTPUT_NOT_NONE_TASKS |
| 183 | ) |
| 184 | |
| 185 | # debugging: monitoring test code size |
| 186 | id2bytes = {} |
| 187 | |
| 188 | n_workers = max(1, multiprocessing.cpu_count() // 4) |
| 189 | with ProcessPoolExecutor(max_workers=n_workers) as executor: |
| 190 | futures = [] |
| 191 | for task_id, plus_form in tqdm(plus_problems.items()): |
| 192 | # The original MBPP dataset identifies tasks by bare number (e.g. "666"), |
| 193 | # whereas EvalPlus uses prefixed ids (e.g. "Mbpp/666"), so strip the |
| 194 | # prefix and convert the remainder to an int. |
| 195 | task_id_int = int(task_id.split("/")[-1]) |
| 196 | if args.debug_tasks and task_id_int not in args.debug_tasks: |
| 197 | continue |
| 198 | |
| 199 | compatible_form = { |
| 200 | "task_id": task_id_int, |
| 201 | "code": plus_form["canonical_solution"], |
| 202 | "prompt": original_mbpp[str(task_id_int)]["prompt"], |
| 203 | "source_file": original_mbpp[str(task_id_int)]["source_file"], |
| 204 | "test_imports": original_mbpp[str(task_id_int)]["test_imports"], |
| 205 | "test_list": original_mbpp[str(task_id_int)]["test_list"], |
| 206 | } |
| 207 | compatible_problems[task_id_int] = compatible_form |
| 208 | |
| 209 | inputs = ( |
| 210 | plus_form["base_input"] + plus_form["plus_input"] |
| 211 | if len(plus_form["plus_input"]) > 0 |
| 212 | else plus_form["base_input"] |
| 213 | ) |
| 214 | results = ( |
| 215 | expected_outputs[task_id]["base"] + expected_outputs[task_id]["plus"] |
| 216 | ) |
| 217 | |
| 218 | inputs, results = deduplicate(inputs, results) |
| 219 | |
| 220 | assert len(inputs) == len(results) |
no test coverage detected