()
| 40 | |
| 41 | |
def main():
    """Write example few-shot prompts for each selected task to a file.

    Reads the command-line arguments via ``parse_args()``, loads the
    requested tasks through a ``TaskManager`` and, for every task, writes
    one file named after the task into ``args.output_base_path``. Each
    example in the file is preceded by ``EXAMPLE_DIVIDER`` formatted with
    the example index, followed by the context returned by
    ``task.fewshot_context``.

    Raises:
        ValueError: If none of the comma-separated names in ``args.sets``
            corresponds to a split the current task provides.
    """
    args = parse_args()
    np.random.seed(args.seed)

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")

    task_manager = TaskManager(include_path=args.include_path)

    # "all_tasks" is a sentinel meaning every task the manager knows about;
    # anything else is treated as a comma-separated list of task names.
    if args.tasks == "all_tasks":
        _task_names = task_manager.all_tasks
    else:
        _task_names = cast("list[str]", args.tasks.split(","))
    _res = task_manager.load(_task_names)
    task_dicts = _res["tasks"].values()

    os.makedirs(args.output_base_path, exist_ok=True)
    for task in task_dicts:
        task_name = task.config.task

        # Collect the document iterables of every requested split this task
        # actually provides; names other than train/val/test are ignored.
        iters = []
        for split_name in args.sets.split(","):
            docs = None
            if split_name == "train" and task.has_training_docs():
                docs = task.training_docs()
            elif split_name == "val" and task.has_validation_docs():
                docs = task.validation_docs()
            elif split_name == "test" and task.has_test_docs():
                docs = task.test_docs()
            if docs is not None:
                iters.append(docs)

        if not iters:
            raise ValueError(
                f"Passed --sets '{args.sets}' but this task has no splits which match. Please specify a different --sets value."
            )

        docs = join_iters(iters)

        with open(
            os.path.join(args.output_base_path, task_name),  # type: ignore
            "w",
            encoding="utf8",
        ) as f:
            # A positive num_examples caps how many documents are written;
            # otherwise every document is written.
            if args.num_examples > 0:
                indexed_docs = zip(range(args.num_examples), docs, strict=False)
            else:
                indexed_docs = enumerate(docs)

            for i, doc in indexed_docs:
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                )
                f.write(ctx + "\n")
no test coverage detected