| 39 | return model, tokenizer |
| 40 | |
def parse_answer_file(answer_file):
    """Score a generated answer file against its reference answers.

    The file is read line by line as a sequence of cases:

    * a line starting with ``===== CASE`` opens a new case;
    * until a line starting with ``Reference Answer`` is reached, the last
      run of digits seen on any line is kept as the predicted answer;
    * after that marker, a line starting with ``####`` carries the
      reference answer, which is compared (as a string) to the prediction.

    The result is printed as a percentage of
    ``len(gsm8k_test['question'])``; ``gsm8k_test`` is a module-level name
    defined outside this function. Nothing is returned.

    Args:
        answer_file: Path of the text file to parse.
    """
    correct = 0
    # None never equals a reference string, so a case without any digits
    # in its answer section is simply counted as wrong.
    predicted = None
    in_reference = False

    # The context manager closes the file even if parsing raises.
    with open(answer_file, 'r') as f:
        for line in f:
            if line.startswith('===== CASE'):
                # New case: drop the previous prediction so it cannot be
                # matched against this case's reference answer.
                predicted = None
                in_reference = False
            elif line.startswith('Reference Answer'):
                # Marker lines are not scanned for digits; otherwise text
                # on this line could overwrite the model's answer.
                in_reference = True
            elif in_reference:
                if line.startswith('####'):
                    reference_answer = line.split('####')[1].strip()
                    if reference_answer == predicted:
                        correct += 1
            else:
                # Still inside the model's answer: remember the last run
                # of digits seen so far.
                numbers = re.findall(r'\d+', line)
                if numbers:
                    predicted = numbers[-1]

    print('Accuracy: ', correct / len(gsm8k_test['question']) * 100)
| 68 | |
| 69 | def main(args): |
| 70 | |