()
| 37 | |
| 38 | |
def _pluralize(count, noun):
    """Return ``"<count> <noun>"`` with an ``s`` appended unless count is exactly 1."""
    return f"{count} {noun}{'' if count == 1 else 's'}"


def _load_task_class(config):
    """Return the task class referenced by a loaded task config.

    If ``config.module`` is set, it is read as a dotted ``package.module.ClassName``
    path: everything before the last dot is imported with ``importlib`` and the
    final component is fetched from that module with ``getattr``. Otherwise the
    class is looked up in ``DEFAULT_CLASS`` using ``config.type`` as the key.

    Raises:
        ValueError: if ``config.module`` contains no dot, so there is no module
            part to import.
        ImportError / AttributeError: propagated from the import / attribute lookup.
        KeyError: if ``config.type`` is not a key of ``DEFAULT_CLASS``.
    """
    if config.module:
        module_path, _, class_name = config.module.rpartition(".")
        if not module_path:
            # importlib.import_module("") would fail with a bare "Empty module name".
            raise ValueError(
                f"Task config module {config.module!r} must be a dotted path "
                f"of the form 'package.module.ClassName'"
            )
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    return DEFAULT_CLASS[config.type]


def main():
    """Evaluate a model on every task config selected on the command line.

    Parses arguments (with evaluation-specific extras), expands ``args.task`` into
    config paths via ``find_all_tasks``, loads each YAML config and resolves its
    task class, builds the model and tokenizer, then runs ``evaluate_all_tasks``
    and logs the elapsed time through ``print_rank_0``.
    """
    args = initialize(extra_args_provider=add_evaluation_specific_args)
    # Overwrite the task selectors with the resolved config paths; this list is
    # passed to evaluate_all_tasks in parallel with task_classes below.
    args.task = find_all_tasks(args.task)

    task_classes = []
    print_rank_0("> Loading task configs")
    for task_config_path in args.task:
        config = BaseConfig.from_yaml_file(task_config_path)
        task_classes.append(_load_task_class(config))
        print_rank_0(f" Task {config.name} loaded from config {task_config_path}")
    print_rank_0(f"> Successfully load {_pluralize(len(task_classes), 'task')}")

    model, tokenizer = initialize_model_and_tokenizer(args)
    model = ModelForEvaluation(model)

    # perf_counter is monotonic, so the reported duration is immune to
    # wall-clock adjustments during a long evaluation run.
    start = time.perf_counter()
    evaluate_all_tasks(args.data_path, model, tokenizer, args.task, task_classes)
    elapsed = time.perf_counter() - start
    print_rank_0(f"Finish {_pluralize(len(task_classes), 'task')} in {elapsed:.1f}s")
| 64 | |
| 65 | |
| 66 | if __name__ == "__main__": |
no test coverage detected