Build model agent executor.
(
model_path: str,
cache_config: CacheConfig,
backend_config: BackendConfig,
dist_config: DistConfig,
misc_config: MiscConfig,
adapters: dict[str, str] = None,
device_type: str = 'cuda',
distributed_executor_backend: str = None,
dtype: str = 'auto',
specdecode_config: SpecDecodeConfig = None,
trust_remote_code: bool = False,
)
| 53 | |
| 54 | |
| 55 | def build_executor( |
| 56 | model_path: str, |
| 57 | cache_config: CacheConfig, |
| 58 | backend_config: BackendConfig, |
| 59 | dist_config: DistConfig, |
| 60 | misc_config: MiscConfig, |
| 61 | adapters: dict[str, str] = None, |
| 62 | device_type: str = 'cuda', |
| 63 | distributed_executor_backend: str = None, |
| 64 | dtype: str = 'auto', |
| 65 | specdecode_config: SpecDecodeConfig = None, |
| 66 | trust_remote_code: bool = False, |
| 67 | ) -> ExecutorBase: |
| 68 | """Build model agent executor.""" |
| 69 | logger = get_logger('lmdeploy') |
| 70 | dp = dist_config.dp |
| 71 | world_size = dist_config.world_size |
| 72 | |
| 73 | model_config = ModelConfig.from_pretrained( |
| 74 | model_path, |
| 75 | trust_remote_code=trust_remote_code, |
| 76 | dtype=dtype, |
| 77 | hf_overrides=misc_config.hf_overrides, |
| 78 | dist_config=dist_config, |
| 79 | is_draft_model=False, |
| 80 | spec_method=None if specdecode_config is None else specdecode_config.method, |
| 81 | num_spec_tokens=0 if specdecode_config is None else specdecode_config.num_speculative_tokens, |
| 82 | model_format=misc_config.model_format, |
| 83 | device_type=device_type, |
| 84 | block_size=cache_config.block_size, |
| 85 | ) |
| 86 | |
| 87 | if distributed_executor_backend is None: |
| 88 | distributed_executor_backend = get_distributed_executor_backend(world_size, dp, device_type, logger) |
| 89 | |
| 90 | if dp > 1: |
| 91 | assert distributed_executor_backend == 'ray', ( |
| 92 | 'dp>1 requires distributed_executor_backend="ray", ', |
| 93 | f'get distributed_executor_backend="{distributed_executor_backend}"') |
| 94 | |
| 95 | if misc_config.empty_init: |
| 96 | assert distributed_executor_backend == 'ray', ( |
| 97 | 'empty_init requires distributed_executor_backend="ray", ', |
| 98 | f'get distributed_executor_backend="{distributed_executor_backend}"') |
| 99 | |
| 100 | if distributed_executor_backend is not None: |
| 101 | logger.info(f'Build <{distributed_executor_backend}> executor.') |
| 102 | if distributed_executor_backend == 'uni': |
| 103 | assert world_size == 1, 'uni executor only support world_size==1.' |
| 104 | from .uni_executor import UniExecutor |
| 105 | return UniExecutor( |
| 106 | model_path=model_path, |
| 107 | model_config=model_config, |
| 108 | cache_config=cache_config, |
| 109 | backend_config=backend_config, |
| 110 | misc_config=misc_config, |
| 111 | adapters=adapters, |
| 112 | device_type=device_type, |
no test coverage detected