```python
        """Initialize any stateful variables for decoding a new sequence"""

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text

        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
```
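The docstring only says that `sum_logprobs` holds "cumulative" log probabilities. One plausible update rule, stated here as an assumption rather than as this codebase's definition, is that each call adds the log probability of the newly selected token and stops adding once a sequence has already emitted end-of-text. With $S_i$ = `sum_logprobs[i]`, $z_i$ = `logits[i]` over a vocabulary of size $V$ = `vocab_size`, $y_i$ the selected token, and $t_i$ the last token currently in `tokens[i]`:

$$
S_i \leftarrow S_i + \mathbf{1}\!\left[t_i \neq \mathrm{eot}\right]\,\Bigl(z_{i,y_i} - \log \sum_{v=1}^{V} e^{z_{i,v}}\Bigr)
$$

The bracketed term is the log-softmax of the logits evaluated at $y_i$, i.e. $\log p(y_i \mid \text{context})$.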
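A minimal sketch of a greedy `update` that follows the contract above, assuming plain Python lists in place of torch `Tensor`s (so it runs without PyTorch) and a caller-supplied end-of-text id. `GreedyDecoderSketch` and `log_softmax` are hypothetical names for illustration, not this codebase's own implementation. The sketch picks the highest-logit token for each sequence, pads already-finished sequences with end-of-text, and reports `completed` once every sequence ends with end-of-text.

```python
import math
from typing import List, Tuple


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax over one row of logits.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


class GreedyDecoderSketch:
    """Hypothetical greedy decoder following the update() contract,
    with Python lists standing in for torch Tensors."""

    def __init__(self, eot: int):
        self.eot = eot  # id of the end-of-text token (assumed, supplied by caller)

    def update(
        self,
        tokens: List[List[int]],
        logits: List[List[float]],
        sum_logprobs: List[float],
    ) -> Tuple[List[List[int]], bool]:
        new_tokens = []
        for i, (seq, row) in enumerate(zip(tokens, logits)):
            if seq[-1] == self.eot:
                # Finished sequence: pad with EOT and leave its score unchanged.
                new_tokens.append(seq + [self.eot])
                continue
            # Greedy choice: the highest-scoring token (first one on ties).
            next_token = max(range(len(row)), key=lambda j: row[j])
            # Accumulate log p(next_token) in place; update() returns no scores.
            sum_logprobs[i] += log_softmax(row)[next_token]
            new_tokens.append(seq + [next_token])
        # Decoding is complete only when every sequence ends with EOT.
        completed = all(seq[-1] == self.eot for seq in new_tokens)
        return new_tokens, completed


# Usage: batch of 2, toy vocabulary of 4 tokens where id 3 is EOT.
decoder = GreedyDecoderSketch(eot=3)
tokens = [[0], [0]]
sum_logprobs = [0.0, 0.0]
tokens, completed = decoder.update(
    tokens, [[0.1, 2.0, 0.3, 0.0], [0.0, 0.0, 0.0, 5.0]], sum_logprobs
)
# tokens == [[0, 1], [0, 3]], completed is False
tokens, completed = decoder.update(
    tokens, [[0.0, 0.0, 0.0, 9.0], [1.0, 0.0, 0.0, 0.0]], sum_logprobs
)
# tokens == [[0, 1, 3], [0, 3, 3]], completed is True
```

Because `update` returns only `(tokens, completed)`, the sketch mutates `sum_logprobs` in place so the caller sees the new scores. Freezing the score of a finished sequence means the end-of-text padding appended on later steps does not change its total.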