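Non-streaming batch generation that just returns the final token sequences. Returns a tuple `(results, masks)`: `results` is a list of token sequences (list of lists of ints), each beginning with a copy of the prompt tokens, and `masks` is a parallel list of per-token mask values (0 for prompt positions). Terminal tokens (assistant_end, bos) are not included in the results.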
(self, tokens, num_samples=1, **kwargs)
| 280 | logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size) |
| 281 | |
def generate_batch(self, tokens, num_samples=1, **kwargs):
    """
    Non-streaming batch generation that just returns the final token sequences.

    Drives ``self.generate`` and accumulates the yielded tokens per row until
    every row has emitted a terminal token or the generator is exhausted.

    Args:
        tokens: prompt token ids (list of ints). Each row starts from a copy,
            so the caller's list is not mutated.
        num_samples: number of rows to generate.
        **kwargs: forwarded unchanged to ``self.generate``.

    Returns:
        A tuple ``(results, masks)``:
          - results: list of ``num_samples`` token sequences (list of lists of
            ints). Each sequence is the prompt ``tokens`` followed by that
            row's generated tokens.
          - masks: list of ``num_samples`` lists, parallel to ``results``.
            Prompt positions are 0; each generated position holds the mask
            value that ``self.generate`` yielded alongside that token.
        Terminal tokens (assistant_end, bos) are not included in either list.
    """
    assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
    bos = self.tokenizer.get_bos_token_id()
    terminal_tokens = (assistant_end, bos)

    results = [tokens.copy() for _ in range(num_samples)]
    masks = [[0] * len(tokens) for _ in range(num_samples)]
    completed = [False] * num_samples

    # self.generate yields one column per step: a token and a mask for each row.
    for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
        for i, (token, mask) in enumerate(zip(token_column, token_masks)):
            if completed[i]:
                # Ignore anything a row produces after its terminal token.
                continue
            if token in terminal_tokens:
                completed[i] = True
            else:
                results[i].append(token)
                masks[i].append(mask)
        # Stop pulling from the generator once every row has terminated.
        if all(completed):
            break
    return results, masks
| 305 | |
| 306 | |
| 307 | if __name__ == "__main__": |