MCPcopy
hub / github.com/FoundationVision/LlamaGen / execute_model

Method execute_model

autoregressive/serve/model_runner.py:845–886  ·  view source on GitHub ↗
(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    )

Source from the content-addressed store, hash-verified

843
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    """Run one forward pass over a batch of sequence groups and sample.

    The batch is converted to model inputs by ``self.prepare_input_tensors``.
    The forward pass runs through an entry of ``self.graph_runners``
    (selected by batch size) for pure-decode batches whose decode metadata
    requests CUDA graphs, and through the eager ``self.model`` otherwise.
    The hidden states are then turned into logits and, on the worker that
    performs sampling, into sampled tokens.

    Args:
        seq_group_metadata_list: Metadata for the sequence groups scheduled
            in this step; passed to ``self.prepare_input_tensors``.
        kv_caches: KV-cache tensors, forwarded unchanged to the model under
            the ``kv_caches`` keyword.

    Returns:
        The result of ``self.model.sample`` when
        ``sampling_metadata.perform_sampling`` is true, otherwise ``None``
        (only the driver worker samples).
    """
    (input_tokens, input_positions, attn_metadata, sampling_metadata,
     lora_requests, lora_mapping, multi_modal_input
     ) = self.prepare_input_tensors(seq_group_metadata_list)
    if self.lora_config:
        self.set_active_loras(lora_requests, lora_mapping)

    # Currently cuda graph is only supported by the decode phase.
    prefill_meta = attn_metadata.prefill_metadata
    decode_meta = attn_metadata.decode_metadata
    # Use the graph runner only for a pure-decode batch. ``decode_meta`` is
    # checked explicitly: if both metadata objects are None (presumably an
    # empty batch -- NOTE(review): confirm against prepare_input_tensors),
    # reading ``decode_meta.use_cuda_graph`` would raise AttributeError.
    # In that case fall back to the eager model.
    if (prefill_meta is None and decode_meta is not None
            and decode_meta.use_cuda_graph):
        # NOTE(review): assumes graph_runners is keyed by the (possibly
        # padded) batch size that prepare_input_tensors produces -- confirm.
        graph_batch_size = input_tokens.shape[0]
        model_executable = self.graph_runners[graph_batch_size]
    else:
        model_executable = self.model

    execute_model_kwargs = {
        "input_ids": input_tokens,
        "positions": input_positions,
        "kv_caches": kv_caches,
        "attn_metadata": attn_metadata,
    }
    if self.vision_language_config:
        execute_model_kwargs["image_input"] = multi_modal_input
    hidden_states = model_executable(**execute_model_kwargs)

    # Compute the logits.
    # NOTE(review): this runs on every worker, before the driver-only check
    # below; presumably compute_logits takes part in a cross-worker
    # collective, so confirm before moving it after the early return.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)

    # Only perform sampling in the driver worker.
    if not sampling_metadata.perform_sampling:
        return None

    # Sample the next token.
    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )
887
888 @torch.inference_mode()
889 def profile_run(self) -> None:

Callers 1

profile_run · Method · 0.95

Calls 5

prepare_input_tensors · Method · 0.95
set_active_loras · Method · 0.95
update · Method · 0.80
compute_logits · Method · 0.80
sample · Method · 0.80

Tested by

no test coverage detected