Repository: github.com/hpcaitech/ColossalAI

Method step

colossalai/inference/core/llm_engine.py:719–758

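In each step, the engine does the following:
1. Run RequestHandler.schedule() to get the batch used for inference.
2. Get the input, input info and output placeholder from the BatchBucket.
3. Run the model to generate the next token.
4. Update the waiting and running lists in RequestHandler and get the finished sequences.
5. Decode and return the finished sequences.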
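The five phases above can be illustrated with a toy, pure-Python engine. Everything in this sketch (`ToySequence`, `ToyEngine`, the `next_token_fn` callback, the first-come-first-served admission rule) is hypothetical and only mirrors the control flow of `step`. It is not ColossalAI's `RequestHandler` or model code, and it leaves KV-cache management out entirely.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToySequence:
    # Hypothetical stand-in for a request; not ColossalAI's Sequence class.
    request_id: int
    token_ids: List[int]
    max_new_tokens: int
    generated: List[int] = field(default_factory=list)

    def finished(self) -> bool:
        return len(self.generated) >= self.max_new_tokens


class ToyEngine:
    """Mimics the phases of step() with plain Python objects."""

    def __init__(self, next_token_fn: Callable[[List[int]], int], max_batch_size: int = 2):
        self.next_token_fn = next_token_fn  # stands in for model forward + token search
        self.max_batch_size = max_batch_size
        self.waiting: List[ToySequence] = []
        self.running: List[ToySequence] = []

    def add_request(self, seq: ToySequence) -> None:
        self.waiting.append(seq)

    def schedule(self) -> List[ToySequence]:
        # 1. Admit waiting requests into the running batch while there is room.
        while self.waiting and len(self.running) < self.max_batch_size:
            self.running.append(self.waiting.pop(0))
        return list(self.running)

    def step(self) -> List[ToySequence]:
        batch = self.schedule()
        # 2-3. Build each sequence's input and "run the model" to get one new token.
        next_tokens = [self.next_token_fn(seq.token_ids + seq.generated) for seq in batch]
        # 4. Append the new tokens, then split finished from still-running sequences.
        for seq, tok in zip(batch, next_tokens):
            seq.generated.append(tok)
        finished = [seq for seq in self.running if seq.finished()]
        self.running = [seq for seq in self.running if not seq.finished()]
        # 5. Return finished sequences; decoding is left to the caller in this toy.
        return finished
```

With a "model" that always predicts the previous token plus one and room for two sequences, submitting requests 0 (two new tokens), 1 and 2 (one new token each) makes request 1 finish on the first step; requests 0 and 2 finish on the second, after request 2 takes the slot request 1 freed.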

(self)


```python
def step(self) -> List[str]:
    """
    In each step, do the following:
    1. Run RequestHandler.schedule() and get the batch used for inference.
    2. Get the input, input info and output placeholder from the BatchBucket.
    3. Run the model to generate the next token.
    4. Update the waiting and running lists in RequestHandler and get the finished sequences.
    5. Decode and return the finished sequences.

    Returns:
        List[str]: Decoded finished sequences generated by one step.
    """

    batch = self.request_handler.schedule()

    input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

    if input_meta_data.use_cuda_graph:
        model_executable = self.graph_runners[input_meta_data.batch_size]
    else:
        model_executable = self.model

    # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
    logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
    if self.inference_config.pad_input:
        logits = logits[:, -1, :]

    if self.inference_config.enable_streamingllm:
        updated_block_ids = batch.streamingllm_update_batch(
            self.inference_config.start_token_size, self.inference_config.generated_token_size
        )
        self.request_handler.streamingllm_free_block_tables(updated_block_ids)

    next_tokens = search_tokens(
        self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
    )
    self.request_handler.append_next_tokens(next_tokens)
    finished_sequences = self.request_handler.update()

    return finished_sequences
```
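The `use_cuda_graph` branch in `step` picks between a captured graph runner and the eager model. The sketch below models that dispatch with plain callables; `graph_runners`, the capture sizes `(1, 2, 4, 8)` and both runner functions are hypothetical stand-ins, not ColossalAI's CUDA-graph runner. Because `step` looks up `self.graph_runners[input_meta_data.batch_size]` directly, it relies on `prepare_input` setting `use_cuda_graph` only when a runner exists for that batch size; this sketch makes the assumption explicit with a fallback to the eager path.

```python
from typing import Callable, Dict, List

Runner = Callable[[List[int]], str]


def eager_model(token_ids: List[int]) -> str:
    # Stand-in for the ordinary eager forward pass: works for any batch size.
    return f"eager(bs={len(token_ids)})"


def make_graph_runner(batch_size: int) -> Runner:
    # Stand-in for a captured graph: replay is only valid for the captured batch size.
    def run(token_ids: List[int]) -> str:
        if len(token_ids) != batch_size:
            raise ValueError("graph replayed with a batch size it was not captured for")
        return f"graph(bs={batch_size})"

    return run


# Hypothetical capture set: one runner per supported batch size.
graph_runners: Dict[int, Runner] = {bs: make_graph_runner(bs) for bs in (1, 2, 4, 8)}


def select_executable(batch_size: int, use_cuda_graph: bool) -> Runner:
    # Same branch as in step(), plus an explicit fallback for uncaptured sizes.
    if use_cuda_graph and batch_size in graph_runners:
        return graph_runners[batch_size]
    return eager_model
```

A lookup table is the natural shape here: a captured CUDA graph replays fixed tensor shapes, so each batch size needs its own runner.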
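When `pad_input` is enabled, the slice `logits[:, -1, :]` implies the model returns logits for every position, shaped `[batch, seq_len, vocab]`, and `step` keeps only the last position, which predicts the next token. That position is correct for every row only if padding sits on the left (or all rows have equal length). The pure-Python sketch below assumes left padding; `PAD_ID`, `left_pad` and `last_position_logits` are hypothetical helpers, not ColossalAI code.

```python
from typing import List

PAD_ID = 0  # hypothetical padding token id


def left_pad(batch: List[List[int]], pad_id: int = PAD_ID) -> List[List[int]]:
    # Left-pad so each sequence's most recent real token lands in the last column.
    width = max(len(seq) for seq in batch)
    return [[pad_id] * (width - len(seq)) + seq for seq in batch]


def last_position_logits(logits: List[List[List[float]]]) -> List[List[float]]:
    # Pure-Python equivalent of logits[:, -1, :] for a [batch][seq_len][vocab] nested list.
    return [per_position[-1] for per_position in logits]
```

For example, `left_pad([[5, 6, 7], [8]])` yields `[[5, 6, 7], [0, 0, 8]]`, so the last column holds the real final tokens `7` and `8` of both rows; with right padding the second row's last column would be a pad token instead.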
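StreamingLLM keeps a few initial "attention sink" tokens plus a window of recent tokens and evicts what falls in between, which bounds KV-cache growth. Judging by the names, `batch.streamingllm_update_batch(...)` in `step` returns ids of KV-cache blocks that are no longer needed, and `streamingllm_free_block_tables` releases them. The following is a simplified, hypothetical version of such block-level eviction for a single sequence. The function name, the rule of evicting once the sequence reaches the sink-plus-window budget plus one full block, and the assumption that the sink tokens fill exactly `block_table[0]` are illustrative choices, not ColossalAI's exact logic.

```python
from typing import List, Tuple


def evict_oldest_window_block(
    block_table: List[int],
    seq_len: int,
    block_size: int,
    start_token_size: int,
    generated_token_size: int,
) -> Tuple[List[int], int, List[int]]:
    """Hypothetical StreamingLLM-style eviction for one sequence, at block granularity.

    Assumes the start ("sink") tokens fill block_table[0] exactly, so block_table[1]
    always holds the oldest evictable tokens.
    Returns (new_block_table, new_seq_len, freed_block_ids).
    """
    budget = start_token_size + generated_token_size
    if seq_len < budget + block_size:
        # Still within the sink + window budget (plus one spare block): nothing to evict.
        return list(block_table), seq_len, []
    freed = block_table[1]
    # Keep the sink block, drop the oldest window block, keep everything newer.
    new_table = [block_table[0]] + block_table[2:]
    return new_table, seq_len - block_size, [freed]
```

With `block_size=4`, `start_token_size=4` and `generated_token_size=8`, a 16-token sequence stored in blocks `[10, 11, 12, 13]` drops block `11` and shrinks to 12 tokens, while a 15-token sequence is left alone. A caller would then hand the freed ids back to its block allocator.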

Callers (3)

generate (method) · 0.95
pixart_alpha_forward (function) · 0.45
sd3_forward (function) · 0.45
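`generate` is the high-confidence caller. Since `step` performs a single iteration, a driver calls it repeatedly and accumulates whatever finishes. Below is a minimal sketch of such a loop, assuming hypothetical `step` and `has_pending` callables; it does not reproduce the actual `generate` method or the request handler's API.

```python
from typing import Callable, List


def run_until_done(
    step: Callable[[], List[str]],
    has_pending: Callable[[], bool],
    max_steps: int = 10_000,
) -> List[str]:
    """Hypothetical driver loop: call step() until no request is pending."""
    outputs: List[str] = []
    for _ in range(max_steps):
        if not has_pending():
            return outputs
        # Each step may finish zero or more sequences; collect them in completion order.
        outputs += step()
    if has_pending():
        raise RuntimeError("requests still pending after max_steps iterations")
    return outputs
```

The `max_steps` guard is a defensive choice for the sketch, so a scheduler that never drains its queue fails loudly instead of looping forever.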

Calls (7)

prepare_input (method) · 0.95
search_tokens (function) · 0.90
append_next_tokens (method) · 0.80
schedule (method) · 0.45
update (method) · 0.45
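`search_tokens` turns the logits into one next token per sequence. Its internals are not shown on this page; the sketch below is a hypothetical stand-in covering only greedy search and temperature plus top-k sampling, controlled by `do_sample`, `temperature` and `top_k` arguments that are assumptions for this example. The real call also receives `is_prompts` and `batch_token_ids`, the latter likely feeding history-dependent processing such as a repetition penalty, which this sketch omits.

```python
import math
import random
from typing import List, Optional


def pick_next_token(
    logits: List[float],
    do_sample: bool = False,
    temperature: float = 1.0,
    top_k: int = 0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical next-token rule for one row of logits (temperature must be > 0)."""
    if not do_sample:
        # Greedy search: the highest logit wins; max() keeps the first index on ties.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng if rng is not None else random.Random()
    scaled = [x / temperature for x in logits]
    # Rank token ids by scaled logit, highest first.
    candidates = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        candidates = candidates[:top_k]  # keep only the k most likely tokens
    # Softmax weights over the survivors; subtracting the max avoids overflow in exp().
    peak = scaled[candidates[0]]
    weights = [math.exp(scaled[i] - peak) for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


def search_batch(batch_logits: List[List[float]], **kwargs) -> List[int]:
    # Apply the same selection rule independently to every sequence in the batch.
    return [pick_next_token(row, **kwargs) for row in batch_logits]
```

Greedy search is deterministic, which makes it convenient for testing; sampling with `top_k=1` collapses to the same choice as greedy search.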

Tested by

no test coverage detected