(self, subject_name, test_df, dev_df=None, few_shot=False, cot=False, save_result_dir=None)
| 35 | continue |
| 36 | |
def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, cot=False, save_result_dir=None):
    """Evaluate the model on one subject and return its accuracy in percent.

    Args:
        subject_name: Subject identifier; forwarded to the few-shot prompt
            builder and to ``self.query``, and used to name the saved CSV.
        test_df: DataFrame of test questions; must have an ``answer`` column.
            When ``save_result_dir`` is set, ``model_output`` and
            ``correctness`` columns are added to it in place.
        dev_df: Examples handed to ``self.generate_few_shot_prompt``; only
            used when ``few_shot`` is true.
        few_shot: If true, prepend the few-shot prompt to every question.
        cot: If true, format questions for chain-of-thought and pull the
            answer out of the response with ``self.extract_cot_answer``;
            otherwise the first character of the response is the answer.
        save_result_dir: If truthy, directory in which
            ``{subject_name}_test.csv`` is written.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when ``test_df`` has
        no rows (instead of raising ``ZeroDivisionError``).
    """
    if few_shot:
        prompt_message = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
    else:
        prompt_message = []

    answers = list(test_df['answer'])
    responses = []
    scores = []

    # iterrows() yields (index label, row). The label is NOT a position, so
    # pair each row with its expected answer positionally via zip instead of
    # indexing `answers` by label; this stays correct for filtered, shuffled
    # or otherwise non-0..n-1 indexed frames.
    rows = tqdm(test_df.iterrows(), total=len(answers))
    for (_, row), expected in zip(rows, answers):
        question = self.format_example(row, include_answer=False, cot=cot)
        # assumes format_example returns the same sequence type as
        # prompt_message so the two can be concatenated -- TODO confirm
        message = prompt_message + question
        response = self.query(subject_name, message)

        if cot:
            # The second element of the returned pair is not needed here.
            ans, _ = self.extract_cot_answer(row, response)
            is_correct = ans == expected
        else:
            # Empty/None responses count as wrong rather than raising.
            is_correct = bool(response) and response[0] == expected

        responses.append(response)
        scores.append(1 if is_correct else 0)

    if save_result_dir:
        test_df['model_output'] = responses
        test_df['correctness'] = scores
        test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv'))

    if not answers:
        return 0.0
    return 100 * sum(scores) / len(answers)
| 73 | |
| 74 | def create_message(self, text, t): |
| 75 | if t == 'user': |
no test coverage detected