In each step, do the following: 1. Run RequestHandler.schedule() and get the batch used for inference. 2. Get the input, input info and output placeholder from the batch bucket. 3. Run the model to generate the next token. 4. Update the waiting list and running list in RequestHandler and get the finished sequences.
(self)
| 717 | return input_ids, output_tensor, input_meta_data |
| 718 | |
def step(self) -> List[str]:
    """
    Run one inference iteration over the batch chosen by the scheduler.

    The iteration performs, in order:
        1. ``self.request_handler.schedule()`` to obtain the batch to run.
        2. ``self.prepare_input(batch)`` to build the input token ids, the
           output placeholder tensor and the input metadata.
        3. A forward pass through either the CUDA-graph runner registered
           for ``meta.batch_size`` (when ``meta.use_cuda_graph`` is set) or
           ``self.model``, passing the KV caches.
        4. Optional StreamingLLM bookkeeping on the batch, followed by
           freeing the block tables it reports back.
        5. Next-token selection via ``search_tokens``, appending those tokens
           through the request handler, then ``request_handler.update()`` to
           refresh the waiting/running lists and collect finished sequences.

    No decoding happens in this method; the result of
    ``request_handler.update()`` is returned unchanged.

    Returns:
        List[str]: The finished sequences reported by
            ``self.request_handler.update()`` for this step.
            NOTE(review): the annotation says ``List[str]``, but the value is
            passed through from ``update()`` as-is — confirm its element type.
    """
    scheduled = self.request_handler.schedule()
    token_ids, out_buffer, meta = self.prepare_input(scheduled)

    # Use the graph runner keyed by batch size when CUDA graphs are enabled
    # for this batch; otherwise fall back to the eager model.
    runner = self.graph_runners[meta.batch_size] if meta.use_cuda_graph else self.model

    # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
    # NOTE(review): no padding_id is passed in the call below; this TODO may be stale.
    logits = runner(token_ids, out_buffer, meta, self.k_cache, self.v_cache)

    config = self.inference_config
    if config.pad_input:
        # Keep only the last index along dim 1 (presumably the sequence
        # axis for padded inputs — confirm against the model's output shape).
        logits = logits[:, -1, :]

    if config.enable_streamingllm:
        block_ids_to_free = scheduled.streamingllm_update_batch(
            config.start_token_size,
            config.generated_token_size,
        )
        self.request_handler.streamingllm_free_block_tables(block_ids_to_free)

    chosen_tokens = search_tokens(
        self.generation_config,
        logits,
        meta.is_prompts,
        batch_token_ids=meta.batch_token_ids,
    )
    self.request_handler.append_next_tokens(chosen_tokens)
    return self.request_handler.update()
no test coverage detected