(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
)
| 843 | |
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    """Run one forward step over ``seq_group_metadata_list`` and sample.

    Steps:
      1. ``self.prepare_input_tensors`` turns the batch into model inputs.
      2. If ``self.lora_config`` is set, the requested LoRA adapters are
         activated before the forward pass.
      3. The forward pass uses ``self.graph_runners[batch_size]`` when the
         batch has no prefill part and its decode metadata reports
         ``use_cuda_graph``. Otherwise ``self.model`` is called directly.
      4. Logits are always computed with ``self.model.compute_logits``.

    Args:
        seq_group_metadata_list: Metadata for the sequence groups in this
            step, handed unchanged to ``prepare_input_tensors``.
        kv_caches: KV-cache tensors, forwarded to the model as the
            ``kv_caches`` keyword argument.

    Returns:
        The result of ``self.model.sample`` when
        ``sampling_metadata.perform_sampling`` is true, otherwise ``None``.
    """
    prepared = self.prepare_input_tensors(seq_group_metadata_list)
    (tokens, positions, attn_metadata, sampling_metadata, lora_requests,
     lora_mapping, multi_modal_input) = prepared

    if self.lora_config:
        self.set_active_loras(lora_requests, lora_mapping)

    # CUDA graphs are currently only supported for the decode phase, so a
    # graph runner is considered only when the batch has no prefill part.
    prefill_meta = attn_metadata.prefill_metadata
    decode_meta = attn_metadata.decode_metadata
    use_graph = prefill_meta is None and decode_meta.use_cuda_graph

    # Graph runners are looked up by the exact leading dimension of the
    # token tensor. This assumes a runner exists for that size (presumably
    # inputs are padded to a captured size upstream; TODO confirm).
    runner = self.graph_runners[tokens.shape[0]] if use_graph else self.model

    forward_kwargs = dict(
        input_ids=tokens,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )
    if self.vision_language_config:
        # Vision-language models additionally receive the multi-modal input.
        forward_kwargs["image_input"] = multi_modal_input
    hidden_states = runner(**forward_kwargs)

    # Logits are computed unconditionally, before the sampling check below.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)

    # Sampling is only performed in the driver worker; elsewhere we stop here.
    if not sampling_metadata.perform_sampling:
        return None

    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )
| 887 | |
| 888 | @torch.inference_mode() |
| 889 | def profile_run(self) -> None: |
no test coverage detected