(solutions: str, output_profiled_solutions: str, pe_inputs: str = None)
| 13 | |
| 14 | |
| 15 | def script(solutions: str, output_profiled_solutions: str, pe_inputs: str = None): |
| 16 | assert solutions.endswith(".jsonl") |
| 17 | assert pe_inputs is None or pe_inputs.endswith(".jsonl") |
| 18 | assert output_profiled_solutions.endswith(".jsonl") |
| 19 | |
| 20 | evalplus = get_human_eval_plus(noextreme=True) |
| 21 | mbppplus = get_mbpp_plus(noextreme=True) |
| 22 | tasks = {**evalplus, **mbppplus} |
| 23 | |
| 24 | # assume each line's format is: { |
| 25 | # "task_id": task's id, |
| 26 | # "inputs": a list of inputs, |
| 27 | inputs_dict = None |
| 28 | |
| 29 | if pe_inputs is not None: |
| 30 | print("Loading performance-exercising inputs...") |
| 31 | with open(pe_inputs, "r") as f: |
| 32 | inputs_dict = { |
| 33 | task["task_id"]: task["inputs"] for l in f for task in [json.loads(l)] |
| 34 | } |
| 35 | |
| 36 | # Notably, the solutions are already validated and cleaned. |
| 37 | with open(solutions, "r") as f: |
| 38 | solutions = {} |
| 39 | for l in f: |
| 40 | solution = json.loads(l) |
| 41 | solutions[solution["task_id"]] = solution["solution"] |
| 42 | |
| 43 | for task_id, task in tqdm(tasks.items()): |
| 44 | if inputs_dict: |
| 45 | inputs = ( |
| 46 | mbpp_deserialize_inputs(task_id, inputs_dict[task_id]) |
| 47 | if "Mbpp/" in task_id |
| 48 | else inputs_dict[task_id] |
| 49 | ) |
| 50 | else: |
| 51 | inputs = task["base_input"] + list(task["plus_input"]) |
| 52 | |
| 53 | input_costs = [] |
| 54 | |
| 55 | if task_id.startswith("HumanEval"): |
| 56 | canonical_solution = task["prompt"] + task["canonical_solution"] |
| 57 | else: |
| 58 | canonical_solution = task["canonical_solution"] |
| 59 | |
| 60 | for inp in inputs: |
| 61 | costs = profile( |
| 62 | canonical_solution, |
| 63 | task["entry_point"], |
| 64 | [inp], |
| 65 | timeout_second_per_test=PERF_CURATE_TIMEOUT_SECOND, |
| 66 | ) |
| 67 | if are_profiles_broken(costs): |
| 68 | continue |
| 69 | input_costs.append((median(costs), inp)) |
| 70 | input_costs.sort(reverse=True, key=lambda x: x[0]) |
| 71 | |
| 72 | for _, pe_input in input_costs: |
nothing calls this directly
no test coverage detected