(args)
# Answer-option labels handed to every evaluator, in order.
choices = list("ABCD")
| 12 | |
def _build_evaluator(args):
    """Instantiate the evaluator whose family matches ``args.model_name``.

    The substring checks are tried in order; the first match wins.
    Returns None when the model name matches none of the known families.
    """
    model_name = args.model_name
    if "turbo" in model_name or "gpt-4" in model_name:
        return ChatGPT_Evaluator(
            choices=choices,
            k=args.ntrain,
            api_key=args.openai_key,
            model_name=model_name,
        )
    if "moss" in model_name:
        return Moss_Evaluator(
            choices=choices,
            k=args.ntrain,
            model_name=model_name,
        )
    if "chatglm" in model_name:
        if args.cuda_device:
            # Restrict the GPUs this process can see; this only takes effect
            # if CUDA has not been initialised yet in this process.
            os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
        # Bound unconditionally: previously `device` was only assigned when
        # --cuda_device was given, so omitting it raised UnboundLocalError.
        device = torch.device("cuda")
        return ChatGLM_Evaluator(
            choices=choices,
            k=args.ntrain,
            model_name=model_name,
            device=device,
        )
    if "minimax" in model_name:
        return MiniMax_Evaluator(
            choices=choices,
            k=args.ntrain,
            group_id=args.minimax_group_id,
            api_key=args.minimax_key,
            model_name=model_name,
        )
    return None


def main(args):
    """Evaluate one subject with the selected model and print its accuracy.

    Reads ``args.model_name``, ``args.subject``, ``args.ntrain``,
    ``args.few_shot``, ``args.cot`` and the model-specific credential/device
    attributes. Loads ``data/val/<subject>_val.csv`` (and, in few-shot mode,
    ``data/dev/<subject>_dev.csv``), and writes results under
    ``logs/<model_name>_<timestamp>``.

    Returns:
        -1 if ``args.model_name`` matches no supported model family;
        otherwise None (the accuracy is printed, not returned).

    Raises:
        FileExistsError: if the per-run results directory already exists
            (e.g. two runs of the same model started within the same second).
    """
    evaluator = _build_evaluator(args)
    if evaluator is None:
        print("Unknown model name")
        return -1

    subject_name = args.subject
    run_date = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
    save_result_dir = os.path.join(r"logs", f"{args.model_name}_{run_date}")
    # makedirs creates "logs" and any intermediate directories (needed when the
    # model name contains a path separator) without an exists/mkdir race.
    # exist_ok is left False so a same-second name collision still raises.
    os.makedirs(save_result_dir)

    print(subject_name)
    val_file_path = os.path.join('data/val', f'{subject_name}_val.csv')
    val_df = pd.read_csv(val_file_path)
    if args.few_shot:
        dev_file_path = os.path.join('data/dev', f'{subject_name}_dev.csv')
        dev_df = pd.read_csv(dev_file_path)
        correct_ratio = evaluator.eval_subject(
            subject_name,
            val_df,
            dev_df,
            few_shot=args.few_shot,
            save_result_dir=save_result_dir,
            cot=args.cot,
        )
    else:
        # NOTE(review): `cot` is not forwarded in zero-shot mode, so --cot
        # has no effect here. Confirm whether this is intended.
        correct_ratio = evaluator.eval_subject(
            subject_name,
            val_df,
            few_shot=args.few_shot,
            save_result_dir=save_result_dir,
        )
    print("Acc:", correct_ratio)
| 66 | |
| 67 | |
| 68 | if __name__ == "__main__": |
no test coverage detected