MCPcopy
hub / github.com/hkust-nlp/ceval / generate

Method generate

code/evaluator_series/evaluators/llama.py:94–136  ·  view source on GitHub ↗
(
        self,
        prompt: str,
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        return_logits: bool = False
    )

Source from the content-addressed store, hash-verified

92 return prompt
93
def generate(
    self,
    prompt: str,
    max_gen_len: int,
    temperature: float = 0.8,
    top_p: float = 0.95,
    return_logits: bool = False
) -> List[str]:
    """Generate a continuation for a single prompt, or return its logits.

    The prompt is encoded with a BOS token and without an EOS token, written
    into a one-row token buffer padded with ``self.tokenizer.pad_id``, and
    extended one token per step by calling ``self.model.forward``.

    Args:
        prompt: Text to condition on.
        max_gen_len: Maximum number of new tokens. The total sequence length
            is additionally capped at ``self.model.params.max_seq_len``.
        temperature: If greater than 0, logits are divided by it, passed
            through softmax and sampled with ``sample_top_p``; otherwise the
            argmax token is taken.
        top_p: Forwarded to ``sample_top_p``.
        return_logits: If true, no generation happens and the result of
            ``self.model.forward`` on the prompt tokens is returned.

    Returns:
        A one-element list with the decoded sequence (prompt tokens
        included), cut before the first EOS token. When ``return_logits`` is
        true, the raw result of ``self.model.forward`` is returned instead,
        despite the ``List[str]`` annotation.

    Raises:
        RuntimeError: If the encoded prompt does not fit into the token
            buffer (more prompt tokens than
            ``min(max_seq_len, max_gen_len + prompt_size)``).
    """
    params = self.model.params
    prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
    prompt_size = len(prompt_tokens)
    total_len = min(params.max_seq_len, max_gen_len + prompt_size)

    # Fail early with a readable message. Without this check the slice
    # assignment below fails with an opaque tensor-shape error. RuntimeError
    # (rather than ValueError) keeps existing handlers for that failure working.
    if prompt_size > total_len:
        raise RuntimeError(
            f"prompt encodes to {prompt_size} tokens but the token buffer "
            f"holds only {total_len} "
            f"(max_seq_len={params.max_seq_len}, max_gen_len={max_gen_len})"
        )

    pad_id = self.tokenizer.pad_id
    tokens = torch.full((1, total_len), pad_id).cuda().long()
    tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long()

    if return_logits:
        return self.model.forward(tokens[:, :prompt_size], 0)

    # True wherever the buffer already holds a non-pad (i.e. prompt) token.
    input_text_mask = tokens != pad_id
    prev_pos = 0
    # NOTE(review): the loop always runs up to total_len and does not stop
    # when EOS is produced; tokens after the first EOS are discarded at
    # decode time below.
    for cur_pos in range(prompt_size, total_len):
        # Only the tokens added since the previous step are fed to the model,
        # together with the position they start at.
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # Keep the existing token wherever the buffer is not padding.
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos

    eos_id = self.tokenizer.eos_id
    decoded = []
    for t in tokens.tolist():
        t = t[: prompt_size + max_gen_len]
        try:
            t = t[: t.index(eos_id)]
        except ValueError:
            # No EOS in the sequence: decode all of it.
            pass
        decoded.append(self.tokenizer.decode(t))
    return decoded
137
138 def extract_model_answer(self,text, a,b,c,d):
139 option_str=re.escape('A. '+a+'\nB. '+b+'\nC. '+c+'\nD. '+d)

Callers 3

eval_subjectMethod · 0.95
generate_distMethod · 0.80
eval_subjectMethod · 0.80

Calls 1

sample_top_pFunction · 0.85

Tested by

no test coverage detected