(
cls,
engine: Engine,
*,
max_output_len: Optional[int],
lora_dir: Optional[List[str]],
rank: int,
debug_mode: bool,
lora_ckpt_source: str,
medusa_choices: List[List[int]],
stream: torch.cuda.Stream,
gpu_weights_percent: float,
enable_context_fmha_fp32_acc: Optional[bool],
multi_block_mode: Optional[bool],
)
| 561 | |
| 562 | @classmethod |
| 563 | def from_engine( |
| 564 | cls, |
| 565 | engine: Engine, |
| 566 | *, |
| 567 | max_output_len: Optional[int], |
| 568 | lora_dir: Optional[List[str]], |
| 569 | rank: int, |
| 570 | debug_mode: bool, |
| 571 | lora_ckpt_source: str, |
| 572 | medusa_choices: List[List[int]], |
| 573 | stream: torch.cuda.Stream, |
| 574 | gpu_weights_percent: float, |
| 575 | enable_context_fmha_fp32_acc: Optional[bool], |
| 576 | multi_block_mode: Optional[bool], |
| 577 | ) -> 'ModelRunner': |
| 578 | model_config = _engine_config_to_model_config( |
| 579 | engine.config, gpu_weights_percent=gpu_weights_percent) |
| 580 | |
| 581 | if model_config.kv_cache_type == KVCacheType.DISABLED: |
| 582 | assert max_output_len == 1 or max_output_len is None, 'Disabled KV cache is intended for context phase only now.' |
| 583 | |
| 584 | pretrained_config = engine.config.pretrained_config |
| 585 | build_config = engine.config.build_config |
| 586 | max_batch_size = build_config.max_batch_size |
| 587 | max_input_len = build_config.max_input_len |
| 588 | max_seq_len = build_config.max_seq_len |
| 589 | max_beam_width = build_config.max_beam_width |
| 590 | if 'GLM' in pretrained_config.architecture and pretrained_config.chatglm_version in [ |
| 591 | 'glm', 'chatglm' |
| 592 | ]: |
| 593 | session_cls = ChatGLMGenerationSession |
| 594 | else: |
| 595 | session_cls = GenerationSession |
| 596 | engine_buffer = engine.engine |
| 597 | runtime_mapping = pretrained_config.mapping |
| 598 | |
| 599 | if medusa_choices is not None: |
| 600 | assert session_cls == GenerationSession, "Medusa is only supported by GenerationSession" |
| 601 | |
| 602 | assert model_config.max_medusa_tokens > 0, \ |
| 603 | "medusa_chioce is specified but model_config.max_medusa_tokens is 0." |
| 604 | |
| 605 | if MpiComm.size() > runtime_mapping.gpus_per_node: |
| 606 | assert MpiComm.local_size() == runtime_mapping.gpus_per_node |
| 607 | if not DISABLE_TORCH_DEVICE_SET: |
| 608 | torch.cuda.set_device(rank % runtime_mapping.gpus_per_node) |
| 609 | session = session_cls(model_config, |
| 610 | engine_buffer, |
| 611 | runtime_mapping, |
| 612 | debug_mode=debug_mode, |
| 613 | stream=stream) |
| 614 | if session.runtime.engine.streamable_weights_size: |
| 615 | session.runtime._set_weight_streaming(gpu_weights_percent) |
| 616 | |
| 617 | if session.use_lora_plugin: |
| 618 | lora_manager = LoraManager(mapping=runtime_mapping, |
| 619 | model_config=model_config) |
| 620 | if lora_dir is not None: |
no test coverage detected