(self, eval_file, **judge_kwargs)
| 46 | |
def evaluate(self, eval_file, **judge_kwargs):
    """Score VQA predictions in ``eval_file`` and return a DataFrame of accuracies.

    Every row is scored with ``process_line``. The scoring method is chosen by
    substring-matching ``self.dataset_name`` against a fixed list of dataset
    names; the first match wins. Per-row results are aggregated with
    ``hit_calculate`` and reported as percentages:

    * one entry per distinct ``split`` value (only if a ``split`` column
      exists), in first-seen order;
    * an ``Overall`` entry over all rows;
    * one entry per distinct ``category`` value (only if a ``category``
      column exists), in sorted order.

    Args:
        eval_file: Path to the prediction file, read with ``load``. It must
            contain ``answer`` and ``prediction`` columns.
        **judge_kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The score table built by ``d2df`` (a DataFrame), rounded to 2
        decimals. The same table is saved via ``dump`` to
        ``<eval_file without extension>_acc.csv``.
    """
    import os.path as osp

    from .utils.vqa_eval import hit_calculate, process_line

    dataset = self.dataset_name
    data = load(eval_file)
    assert 'answer' in data and 'prediction' in data
    data['prediction'] = [str(x) for x in data['prediction']]
    data['answer'] = [str(x) for x in data['answer']]
    lines = [data.iloc[i] for i in range(len(data))]

    # (name patterns, scoring method) pairs, checked in order; first match wins.
    method = None
    for patterns, candidate in (
        (['TextVQA'], 'vqa_score'),
        (['ChartQA'], 'relaxed_accuracy'),
        (['OCRVQA', 'GQA'], 'accuracy'),
        (['DocVQA', 'InfoVQA'], 'anls'),
    ):
        if listinstr(patterns, dataset):
            method = candidate
            break
    # Unrecognized datasets use process_line's own default method
    # (vqa_score according to the previous inline comment -- confirm in vqa_eval).
    if method is None:
        score_fn = process_line
    else:
        score_fn = partial(process_line, method=method)

    # The context manager guarantees the worker processes are released;
    # previously the pool was never closed.
    with mp.Pool(16) as pool:
        res = pool.map(score_fn, lines)

    def score(rows):
        # Mean hit rate over the given per-row results, as a percentage.
        return np.mean(hit_calculate(rows, dataset)) * 100

    ret = dict()
    if 'split' in data:
        # dict.fromkeys keeps first-seen order, so the output columns are
        # reproducible (set order of str values varies between runs).
        for sp in dict.fromkeys(data['split']):
            ret[sp] = score([r for row, r in zip(lines, res) if row['split'] == sp])
    ret['Overall'] = score(res)
    if 'category' in data:
        for cat in sorted(set(data['category'])):
            ret[cat] = score([r for row, r in zip(lines, res) if row['category'] == cat])

    # DataFrame.round returns a new object; its result must be kept.
    ret = d2df(ret).round(2)

    # splitext strips only the final extension; str.replace would also rewrite
    # earlier occurrences of the suffix elsewhere in the path.
    result_file = osp.splitext(eval_file)[0] + '_acc.csv'
    dump(ret, result_file)
    return ret
| 98 | |
| 99 | |
| 100 | class VizWiz(ImageBaseDataset): |
nothing calls this directly
no test coverage detected