| 81 | |
| 82 | |
def decode_bertqa_output(inputs_text, hf_tokenizer,
                         start_logits: Tuple[torch.Tensor, ...],
                         end_logits: Tuple[torch.Tensor, ...]):
    """Decode the predicted answer span for every example in a QA batch.

    For each example the span runs from the argmax of its start logits to
    the argmax of its end logits (inclusive); the token ids in that span
    are turned back into text with ``hf_tokenizer.decode``.

    Args:
        inputs_text: Mapping with ``'text'`` and ``'text_pair'`` entries
            (presumably the questions and their contexts) of equal length.
            The whole mapping is also forwarded to ``hf_tokenizer`` as
            keyword arguments.
        hf_tokenizer: Callable tokenizer returning a mapping that contains
            ``'input_ids'`` and exposing ``decode``.
        start_logits: One tensor of start scores per example; the argmax is
            taken along dim 0 (assumes per-token scores -- confirm with the
            caller).
        end_logits: One tensor of end scores per example, same layout as
            ``start_logits``.

    Returns:
        A list with one decoded answer string per example. An example whose
        predicted end precedes its predicted start yields the decoding of
        an empty token span.

    Raises:
        ValueError: If ``'text'`` and ``'text_pair'`` differ in length.
    """
    question, context = inputs_text['text'], inputs_text['text_pair']
    # Raise explicitly instead of using ``assert`` so the check still runs
    # when Python is started with -O (which strips assert statements).
    if len(context) != len(question):
        raise ValueError(
            f"'text' and 'text_pair' must have the same length, got "
            f"{len(question)} and {len(context)}")
    batch_size = len(context)

    # Regenerate the input ids with padding: the ids fed to the engine are
    # flattened when remove_input_padding=True, so they cannot be indexed
    # per example.
    inputs = hf_tokenizer(**inputs_text, padding=True, return_tensors='pt')
    inputs_ids = inputs['input_ids']

    decode_answer = []
    for i, example_ids in enumerate(inputs_ids[:batch_size]):
        answer_start = int(start_logits[i].argmax(dim=0))
        answer_end = int(end_logits[i].argmax(dim=0))
        # The end index is inclusive, hence the +1 on the slice bound.
        predict_answer_tokens = example_ids[answer_start:answer_end + 1]
        decode_answer.append(
            hf_tokenizer.decode(predict_answer_tokens,
                                skip_special_tokens=True))
    return decode_answer
| 103 | |
| 104 | |
| 105 | def compare_bertqa_result(inputs_text, res_answers, ref_answers): |