(sample_dir: str, sample_eval_dir: str, dataset: str)
| 8 | |
| 9 | |
def collect_sample_info(sample_dir: str, sample_eval_dir: str, dataset: str):
    """Collect, per task, which model samples fail which "plus" tests, and cache it.

    For every task id returned by ``get_task_ids(dataset)``, build a dict that maps
    a test id ``"plus_<i>"`` to a list of ``(model_path, sample_index)`` pairs whose
    sample failed that test. Each task's dict is pickled to
    ``<sample_dir>/<to_path(task_id)>.pkl``.

    Args:
        sample_dir: Output/cache directory. If it already exists and is non-empty,
            the function returns immediately without recomputing anything.
        sample_eval_dir: Directory holding one sub-directory per model run. Only
            sub-directories whose name ends in a digit are read, and only when they
            contain an ``eval_results.json`` file.
        dataset: Dataset name forwarded to ``get_task_ids``.

    Raises:
        AssertionError: if ``sample_eval_dir`` does not exist (and no cache exists).
    """
    if os.path.exists(sample_dir) and len(os.listdir(sample_dir)) > 0:
        # cache file exists
        return
    task_ids = get_task_ids(dataset)
    assert os.path.exists(sample_eval_dir), "sample evaluation files missing"
    os.makedirs(sample_dir, exist_ok=True)
    # Keys are exactly `task_ids`, so this dict doubles as an O(1) membership set.
    kill_info = {task_id: {} for task_id in task_ids}
    model_paths = os.listdir(sample_eval_dir)
    for model_path in track(model_paths, description="Collecting sets..."):
        if not model_path[-1].isdigit():
            continue
        eval_json_path = os.path.join(
            sample_eval_dir, model_path, "eval_results.json"
        )
        if not os.path.exists(eval_json_path):
            continue
        with open(eval_json_path, "r") as f:
            eval_results = json.load(f)["eval"]
        for raw_task_id, task_result in eval_results.items():
            # Some result files may key tasks with "_" instead of "/"
            # (e.g. "HumanEval_0") -- TODO confirm which producers do this.
            # Normalize *before* the membership test, and only when the raw key
            # is unknown, so ids that legitimately contain "_" are kept as-is.
            task_id = raw_task_id
            if task_id not in kill_info:
                task_id = task_id.replace("_", "/")
                if task_id not in kill_info:
                    continue
            task_kills = kill_info[task_id]
            # Each entry of "plus" unpacks to (status, per-test results).
            for i_code, (status, test_results) in enumerate(task_result["plus"]):
                if status == "success":
                    continue
                for i_test, passed in enumerate(test_results):
                    # Per-test results are assumed to be JSON booleans; a false
                    # entry means this sample failed ("was killed by") that test.
                    if not passed:
                        task_kills.setdefault(f"plus_{i_test}", []).append(
                            (model_path, i_code)
                        )
    for task_id in task_ids:
        path_task_id = to_path(task_id)
        with open(os.path.join(sample_dir, f"{path_task_id}.pkl"), "wb") as f:
            pickle.dump(kill_info[task_id], f)
| 45 | |
| 46 | |
| 47 | if __name__ == "__main__": |
no test coverage detected