MCPcopy
hub / github.com/InternLM/lmdeploy / from_pretrained

Method from_pretrained

lmdeploy/pytorch/config.py:382–445  ·  view source on GitHub ↗

Instantiate one of the configuration classes of the library from a pretrained model configuration. Args: pretrained_model_name_or_path (str): the pretrained model path; trust_remote_code (bool): Whether or not to allow for custom models defined on the Hub in their own modeling files.

(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = False,
        dtype: str = 'auto',
        dist_config: DistConfig = None,
        hf_overrides: dict[str, Any] = None,
        is_draft_model: bool = False,
        spec_method: str = None,
        num_spec_tokens: int = 0,
        model_format: str = None,
        device_type: str = 'auto',
        block_size: int = 64,
    )

Source from the content-addressed store, hash-verified

380
381 @classmethod
382 def from_pretrained(
383 cls,
384 pretrained_model_name_or_path: str,
385 trust_remote_code: bool = False,
386 dtype: str = 'auto',
387 dist_config: DistConfig = None,
388 hf_overrides: dict[str, Any] = None,
389 is_draft_model: bool = False,
390 spec_method: str = None,
391 num_spec_tokens: int = 0,
392 model_format: str = None,
393 device_type: str = 'auto',
394 block_size: int = 64,
395 ):
396 """Instantiate one of the configuration classes of the library from a
397 pretrained model configuration.
398
399 Args:
400 pretrained_model_name_or_path (str): the pretrained model path
401 trust_remote_code (bool): Whether or not to allow for custom
402 models defined on the Hub in their own modeling files.
403 dtype (str): user specified data type for model weights and
404 activations. Refer to `PyTorchEngineConfig` for details
405 hf_overrides (dict[str, Any]): overrides for the HF config.
406 model_format (str): the quantization format of the model.
407 """
408 from transformers import AutoConfig
409
410 from lmdeploy.pytorch.transformers import config_from_pretrained
411 hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
412 if getattr(hf_config, 'model_type', None) in ['phi3']:
413 # phi3 + trust_remote_code leads to error when tp.
414 hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
415
416 # update quantization config
417 hf_config = _patch_quantization_config(hf_config, model_format=model_format)
418
419 model_config = cls.from_hf_config(
420 hf_config,
421 pretrained_model_name_or_path,
422 dtype=dtype,
423 dist_config=dist_config,
424 is_draft_model=is_draft_model,
425 spec_method=spec_method,
426 num_spec_tokens=num_spec_tokens,
427 device_type=device_type,
428 )
429 fp32_lm_head = False
430 if hf_overrides is not None:
431 logger.warning(f'Overriding HF config with {hf_overrides}')
432 fp32_lm_head = hf_overrides.pop('fp32_lm_head', False)
433 override_hf_config(model_config.hf_config, hf_overrides)
434
435 # for fp32 head
436 model_config.fp32_lm_head = fp32_lm_head
437 model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
438
439 # for serialization of transformers modules

Callers 15

get_hf_gen_cfgFunction · 0.45
__init__Method · 0.45
__init__Method · 0.45
get_model_archFunction · 0.45
__init__Method · 0.45
mainFunction · 0.45
build_preprocessorMethod · 0.45
build_modelMethod · 0.45
build_preprocessorMethod · 0.45
build_modelMethod · 0.45
build_modelMethod · 0.45

Calls 7

config_from_pretrainedFunction · 0.90
override_hf_configFunction · 0.85
from_hf_configMethod · 0.80
popMethod · 0.45
from_configMethod · 0.45