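Instantiate one of the configuration classes of the library from a pretrained model configuration. Args: pretrained_model_name_or_path (str): the pretrained model path. trust_remote_code (bool): Whether or not to allow for custom models defined on the Hub in their own modeling files. dtype (str): user specified data type for model weights and activations. Refer to `PyTorchEngineConfig` for details. hf_overrides (dict[str, Any]): overrides for the HF config. model_format (str): the quantization format of the model.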
(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = False,
dtype: str = 'auto',
dist_config: DistConfig = None,
hf_overrides: dict[str, Any] = None,
is_draft_model: bool = False,
spec_method: str = None,
num_spec_tokens: int = 0,
model_format: str = None,
device_type: str = 'auto',
block_size: int = 64,
)
| 380 | |
| 381 | @classmethod |
| 382 | def from_pretrained( |
| 383 | cls, |
| 384 | pretrained_model_name_or_path: str, |
| 385 | trust_remote_code: bool = False, |
| 386 | dtype: str = 'auto', |
| 387 | dist_config: DistConfig = None, |
| 388 | hf_overrides: dict[str, Any] = None, |
| 389 | is_draft_model: bool = False, |
| 390 | spec_method: str = None, |
| 391 | num_spec_tokens: int = 0, |
| 392 | model_format: str = None, |
| 393 | device_type: str = 'auto', |
| 394 | block_size: int = 64, |
| 395 | ): |
| 396 | """Instantiate one of the configuration classes of the library from a |
| 397 | pretrained model configuration. |
| 398 | |
| 399 | Args: |
| 400 | pretrained_model_name_or_path (str): the pretrained model path |
| 401 | trust_remote_code (bool): Whether or not to allow for custom |
| 402 | models defined on the Hub in their own modeling files. |
| 403 | dtype (str): user specified data type for model weights and |
| 404 | activations. Refer to `PyTorchEngineConfig` for details |
| 405 | hf_overrides (dict[str, Any]): overrides for the HF config. |
| 406 | model_format (str): the quantization format of the model. |
| 407 | """ |
| 408 | from transformers import AutoConfig |
| 409 | |
| 410 | from lmdeploy.pytorch.transformers import config_from_pretrained |
| 411 | hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) |
| 412 | if getattr(hf_config, 'model_type', None) in ['phi3']: |
| 413 | # phi3 + trust_remote_code leads to error when tp. |
| 414 | hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path) |
| 415 | |
| 416 | # update quantization config |
| 417 | hf_config = _patch_quantization_config(hf_config, model_format=model_format) |
| 418 | |
| 419 | model_config = cls.from_hf_config( |
| 420 | hf_config, |
| 421 | pretrained_model_name_or_path, |
| 422 | dtype=dtype, |
| 423 | dist_config=dist_config, |
| 424 | is_draft_model=is_draft_model, |
| 425 | spec_method=spec_method, |
| 426 | num_spec_tokens=num_spec_tokens, |
| 427 | device_type=device_type, |
| 428 | ) |
| 429 | fp32_lm_head = False |
| 430 | if hf_overrides is not None: |
| 431 | logger.warning(f'Overriding HF config with {hf_overrides}') |
| 432 | fp32_lm_head = hf_overrides.pop('fp32_lm_head', False) |
| 433 | override_hf_config(model_config.hf_config, hf_overrides) |
| 434 | |
| 435 | # for fp32 head |
| 436 | model_config.fp32_lm_head = fp32_lm_head |
| 437 | model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False) |
| 438 | |
| 439 | # for serialization of transformers modules |