(
self,
prompt: str,
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
return_logits: bool = False
)
| 92 | return prompt |
| 93 | |
def generate(
    self,
    prompt: str,
    max_gen_len: int,
    temperature: float = 0.8,
    top_p: float = 0.95,
    return_logits: bool = False
) -> List[str]:
    """Generate a completion for a single prompt.

    The prompt is encoded with a BOS token and no EOS token, then extended
    one token at a time until ``max_gen_len`` new tokens exist, the
    model's ``params.max_seq_len`` is reached, or the EOS token is
    produced.

    Args:
        prompt: Text to condition the generation on.
        max_gen_len: Maximum number of tokens to generate after the prompt.
        temperature: Softmax temperature. Values ``<= 0`` select greedy
            (argmax) decoding instead of sampling.
        top_p: Threshold passed to ``sample_top_p`` when sampling.
        return_logits: When true, skip generation and return whatever
            ``self.model.forward`` yields for the prompt tokens.

    Returns:
        A one-element list holding the decoded text (prompt tokens
        included), cut just before the first EOS token if one occurs.
        NOTE: with ``return_logits=True`` the raw output of
        ``self.model.forward`` is returned instead, which does not match
        the ``List[str]`` annotation.
    """
    params = self.model.params
    prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
    prompt_size = len(prompt_tokens)
    total_len = min(params.max_seq_len, max_gen_len + prompt_size)

    # Allocate directly on the target device. Fall back to CPU when CUDA is
    # unavailable instead of crashing inside ``.cuda()``.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokens = torch.full(
        (1, total_len), self.tokenizer.pad_id, dtype=torch.long, device=device
    )
    tokens[0, :prompt_size] = torch.tensor(
        prompt_tokens, dtype=torch.long, device=device
    )

    if return_logits:
        return self.model.forward(tokens[:, :prompt_size], 0)

    eos_id = self.tokenizer.eos_id
    prev_pos = 0
    for cur_pos in range(prompt_size, total_len):
        # Only the tokens added since the previous step are fed in.
        # Presumably the model caches earlier positions and returns logits
        # for the last position only - TODO confirm against model.forward.
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # With a single prompt every position from prompt_size onwards is
        # padding, so the new token can be written without a prompt mask.
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos
        # Everything after the first EOS is dropped when decoding, so
        # further forward passes would be wasted work.
        if next_token.item() == eos_id:
            break

    decoded = []
    for t in tokens.tolist():
        try:
            t = t[: t.index(eos_id)]
        except ValueError:
            # No EOS present: keep the whole sequence.
            pass
        decoded.append(self.tokenizer.decode(t))
    return decoded
| 137 | |
| 138 | def extract_model_answer(self,text, a,b,c,d): |
| 139 | option_str=re.escape('A. '+a+'\nB. '+b+'\nC. '+c+'\nD. '+d) |
no test coverage detected