Repository: github.com/NVIDIA/TensorRT-LLM

Function decode_bertqa_output

examples/models/core/bert/utils.py, lines 83–102
Signature: decode_bertqa_output(inputs_text, hf_tokenizer, start_logits: Tuple[torch.Tensor], end_logits: Tuple[torch.Tensor])
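The signature takes start_logits and end_logits as tuples holding one tensor per sequence, while the function's source notes that the engine inputs are flattened when remove_input_padding=True. Below is a minimal sketch of how a packed, padding-free per-token output could be split back into per-sequence pieces. It assumes rows are concatenated in batch order and that their lengths are known. split_packed is a hypothetical helper, not part of TensorRT-LLM, and plain Python lists stand in for tensors.

```python
from typing import List, Sequence, Tuple


def split_packed(packed: Sequence[float],
                 lengths: Sequence[int]) -> Tuple[List[float], ...]:
    """Split a flat, padding-free per-token sequence into one chunk per row.

    Hypothetical helper: assumes rows were concatenated back to back in
    batch order, so row i occupies the next lengths[i] positions.
    """
    if sum(lengths) != len(packed):
        raise ValueError("lengths must sum to len(packed)")
    chunks = []
    offset = 0
    for n in lengths:
        chunks.append(list(packed[offset:offset + n]))
        offset += n
    return tuple(chunks)


# Two sequences of 2 and 3 tokens packed into one flat list of logits.
print(split_packed([0.1, 0.9, 0.2, 0.3, 0.8], [2, 3]))
# prints ([0.1, 0.9], [0.2, 0.3, 0.8])
```

With real tensors, `torch.split(packed, lengths)` performs the same split along dimension 0 and returns a tuple of tensors, which matches the `Tuple[torch.Tensor]` annotation on the parameters.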


from typing import Tuple

import torch


def decode_bertqa_output(inputs_text, hf_tokenizer,
                         start_logits: Tuple[torch.Tensor],
                         end_logits: Tuple[torch.Tensor]):
    question, context = inputs_text['text'], inputs_text['text_pair']
    assert len(context) == len(question)
    batch_size = len(context)

    # Re-tokenize to recover a padded [batch, seq_len] input_ids tensor,
    # because the engine inputs are flattened when remove_input_padding=True.
    inputs = hf_tokenizer(**inputs_text, padding=True, return_tensors='pt')
    inputs_ids = inputs['input_ids']
    # Position of the highest-scoring start/end token in each sequence.
    answer_start_index = [logit.argmax(dim=0) for logit in start_logits]
    answer_end_index = [logit.argmax(dim=0) for logit in end_logits]
    decode_answer = []
    for i in range(batch_size):
        # The answer span includes the end token, hence the + 1.
        predict_answer_tokens = inputs_ids[
            i, answer_start_index[i]:answer_end_index[i] + 1]
        predict_text = hf_tokenizer.decode(predict_answer_tokens,
                                           skip_special_tokens=True)
        decode_answer.append(predict_text)
    return decode_answer
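The decoding step can be sketched without torch or a Hugging Face tokenizer. For each sequence, take the argmax of the start and end logits, slice the padded token ids over the inclusive span [start, end], drop special tokens, and map the remaining ids to strings. This is a minimal illustration under stated assumptions: VOCAB, SPECIAL_IDS, decode_spans, and one_hot are hypothetical, and joining tokens with spaces is a simplification of hf_tokenizer.decode, which also merges WordPiece sub-tokens.

```python
from typing import AbstractSet, Dict, List, Sequence

# Hypothetical toy vocabulary; real BERT vocabularies are far larger.
VOCAB: Dict[int, str] = {
    0: "[PAD]", 101: "[CLS]", 102: "[SEP]",
    2: "where", 3: "is", 4: "paris", 5: "in", 6: "france",
    7: "who", 8: "wrote", 9: "it", 10: "bob",
}
SPECIAL_IDS = frozenset({0, 101, 102})


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value; the first maximal index wins on ties."""
    return max(range(len(values)), key=values.__getitem__)


def decode_spans(padded_ids: Sequence[Sequence[int]],
                 start_logits: Sequence[Sequence[float]],
                 end_logits: Sequence[Sequence[float]],
                 vocab: Dict[int, str],
                 special_ids: AbstractSet[int] = frozenset()) -> List[str]:
    """Turn per-sequence start/end logits into answer strings."""
    answers = []
    for row, s_logits, e_logits in zip(padded_ids, start_logits, end_logits):
        start, end = argmax(s_logits), argmax(e_logits)
        # Inclusive span [start, end]; empty when end < start.
        span = row[start:end + 1]
        # Rough stand-in for decode(..., skip_special_tokens=True).
        answers.append(" ".join(vocab[t] for t in span if t not in special_ids))
    return answers


def one_hot(n: int, i: int) -> List[float]:
    """Logits of length n that peak at position i (demo input only)."""
    v = [0.0] * n
    v[i] = 1.0
    return v


if __name__ == "__main__":
    # Row 0: [CLS] where is paris [SEP] paris is in france [SEP]
    # Row 1: [CLS] who wrote it [SEP] bob wrote it [SEP] [PAD]
    ids = [[101, 2, 3, 4, 102, 4, 3, 5, 6, 102],
           [101, 7, 8, 9, 102, 10, 8, 9, 102, 0]]
    # Row 1 has 9 real tokens, so its unpadded logits have length 9.
    starts = [one_hot(10, 7), one_hot(9, 5)]
    ends = [one_hot(10, 8), one_hot(9, 5)]
    print(decode_spans(ids, starts, ends, VOCAB, SPECIAL_IDS))
    # prints ['in france', 'bob']
```

As in the function above, an end index that falls before the start index produces an empty slice and therefore an empty answer; nothing constrains end >= start.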

Callers (1)

run.py (file)

Calls (2)

decode (method)
append (method)

Tested by

no test coverage detected