MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / from_engine

Method from_engine

tensorrt_llm/runtime/model_runner.py:563–636  ·  view source on GitHub ↗
(
        cls,
        engine: Engine,
        *,
        max_output_len: Optional[int],
        lora_dir: Optional[List[str]],
        rank: int,
        debug_mode: bool,
        lora_ckpt_source: str,
        medusa_choices: List[List[int]],
        stream: torch.cuda.Stream,
        gpu_weights_percent: float,
        enable_context_fmha_fp32_acc: Optional[bool],
        multi_block_mode: Optional[bool],
    )

Source from the content-addressed store, hash-verified

561
562 @classmethod
563 def from_engine(
564 cls,
565 engine: Engine,
566 *,
567 max_output_len: Optional[int],
568 lora_dir: Optional[List[str]],
569 rank: int,
570 debug_mode: bool,
571 lora_ckpt_source: str,
572 medusa_choices: List[List[int]],
573 stream: torch.cuda.Stream,
574 gpu_weights_percent: float,
575 enable_context_fmha_fp32_acc: Optional[bool],
576 multi_block_mode: Optional[bool],
577 ) -> 'ModelRunner':
578 model_config = _engine_config_to_model_config(
579 engine.config, gpu_weights_percent=gpu_weights_percent)
580
581 if model_config.kv_cache_type == KVCacheType.DISABLED:
582 assert max_output_len == 1 or max_output_len is None, 'Disabled KV cache is intended for context phase only now.'
583
584 pretrained_config = engine.config.pretrained_config
585 build_config = engine.config.build_config
586 max_batch_size = build_config.max_batch_size
587 max_input_len = build_config.max_input_len
588 max_seq_len = build_config.max_seq_len
589 max_beam_width = build_config.max_beam_width
590 if 'GLM' in pretrained_config.architecture and pretrained_config.chatglm_version in [
591 'glm', 'chatglm'
592 ]:
593 session_cls = ChatGLMGenerationSession
594 else:
595 session_cls = GenerationSession
596 engine_buffer = engine.engine
597 runtime_mapping = pretrained_config.mapping
598
599 if medusa_choices is not None:
600 assert session_cls == GenerationSession, "Medusa is only supported by GenerationSession"
601
602 assert model_config.max_medusa_tokens > 0, \
603 "medusa_chioce is specified but model_config.max_medusa_tokens is 0."
604
605 if MpiComm.size() > runtime_mapping.gpus_per_node:
606 assert MpiComm.local_size() == runtime_mapping.gpus_per_node
607 if not DISABLE_TORCH_DEVICE_SET:
608 torch.cuda.set_device(rank % runtime_mapping.gpus_per_node)
609 session = session_cls(model_config,
610 engine_buffer,
611 runtime_mapping,
612 debug_mode=debug_mode,
613 stream=stream)
614 if session.runtime.engine.streamable_weights_size:
615 session.runtime._set_weight_streaming(gpu_weights_percent)
616
617 if session.use_lora_plugin:
618 lora_manager = LoraManager(mapping=runtime_mapping,
619 model_config=model_config)
620 if lora_dir is not None:

Callers 1

from_dir — Method · 0.45

Calls 5

load_from_ckpt — Method · 0.95
LoraManager — Class · 0.85
size — Method · 0.45
_set_weight_streaming — Method · 0.45

Tested by

no test coverage detected