(example: Dict[str, Any], data_name)
| 239 | |
| 240 | |
def parse_ground_truth(example: Dict[str, Any], data_name: str):
    """Return the reference ``(chain_of_thought, answer)`` pair for an example.

    If ``example`` already carries a ``'gt_cot'`` key, that value is returned
    as-is together with ``strip_string(example['gt'])``. Otherwise the fields
    to read are selected by ``data_name``.

    Args:
        example: One dataset record. Which keys are read depends on
            ``data_name`` (e.g. ``'solution'`` for ``"math"``/``"ocw"``,
            ``'answer'`` for ``"gsm8k"``).
        data_name: Dataset identifier; one of ``"math"``, ``"ocw"``,
            ``"gsm8k"``, ``"gsm-hard"``, ``"svamp"``, ``"asdiv"``,
            ``"mawps"``, ``"tabmwp"`` or ``"bbh"``.

    Returns:
        A ``(gt_cot, gt_ans)`` tuple. On the per-dataset path ``gt_cot`` is
        ``str(raw).strip()`` (so a missing chain of thought becomes the
        string ``"None"``) and ``gt_ans`` is the result of ``strip_string``.

    Raises:
        NotImplementedError: If ``data_name`` is not one of the datasets
            listed above.
        ValueError: If a ``"gsm8k"`` answer does not split into exactly two
            parts on ``"####"``, or a ``"tabmwp"`` numeric answer cannot be
            parsed.
    """
    # Records that already hold a parsed ground truth skip dataset dispatch.
    if 'gt_cot' in example:
        return example['gt_cot'], strip_string(example['gt'])

    # parse ground truth
    if data_name in ("math", 'ocw'):
        gt_cot = example['solution']
        gt_ans = extract_answer(gt_cot)
    elif data_name == "gsm8k":
        # Text before "####" is the rationale, text after it is the answer.
        gt_cot, gt_ans = example['answer'].split("####")
    elif data_name == "gsm-hard":
        gt_cot, gt_ans = example['code'], example['target']
    elif data_name == "svamp":
        gt_cot, gt_ans = example['Equation'], example['Answer']
    elif data_name == "asdiv":
        gt_cot = example['formula']
        # Remove parenthesised segments from the answer string.
        gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
    elif data_name == "mawps":
        gt_cot, gt_ans = None, example['target']
    elif data_name == "tabmwp":
        gt_cot = example['solution']
        gt_ans = example['answer']
        if example['ans_type'] in ('integer_number', 'decimal_number'):
            # Strip thousands separators up front so they may co-occur with
            # a fraction or percent sign (e.g. "1,250%" or "1,000/4"), which
            # previously raised ValueError.
            number = gt_ans.replace(',', '')
            if '/' in number:
                parts = number.split('/')
                gt_ans = int(parts[0]) / int(parts[1])
            elif '%' in number:
                gt_ans = float(number.split('%')[0]) / 100
            else:
                gt_ans = float(number)
    elif data_name == "bbh":
        gt_cot, gt_ans = None, example['target']
    else:
        raise NotImplementedError(data_name)

    # post process
    gt_cot = str(gt_cot).strip()
    gt_ans = strip_string(gt_ans)
    return gt_cot, gt_ans
| 280 | |
| 281 | |
| 282 | def parse_question(example, data_name): |
no test coverage detected