(self, model_dir: str, trust_remote_code: bool = False)
| 68 | self.token2id = {} |
| 69 | |
def _check_transformers_version(self, model_dir: str, trust_remote_code: bool = False):
    """Warn when the installed ``transformers`` is older than the model expects.

    The expected version is taken from the ``transformers_version`` field of
    the model config returned by ``get_model_arch``. If the top-level config
    has no such field, the nested ``llm_config`` (if present) is consulted.
    The check is advisory only: it logs a warning and never raises, so an
    absent or unparsable version string simply skips the check.

    Args:
        model_dir: Path of the model whose config is inspected.
        trust_remote_code: Forwarded to ``get_model_arch``.
    """
    import transformers
    from packaging import version

    from lmdeploy.archs import get_model_arch

    logger = get_logger('lmdeploy')

    def _get_field(config, name):
        # Sub-configs may be config objects or plain dicts (a generic config
        # loader can keep unknown keys such as ``llm_config`` as raw JSON).
        if config is None:
            return None
        if isinstance(config, dict):
            return config.get(name)
        return getattr(config, name, None)

    current_transformers_version = version.parse(transformers.__version__)
    cfg = get_model_arch(model_dir, trust_remote_code=trust_remote_code)[1]
    cfg_ver = _get_field(cfg, 'transformers_version')
    if cfg_ver is None:
        # Multimodal wrappers often record the version only on the LLM part.
        llm_config = _get_field(cfg, 'llm_config')
        cfg_ver = _get_field(llm_config, 'transformers_version')
    if cfg_ver is None:
        return

    try:
        required_transformers_version = version.parse(str(cfg_ver))
    except version.InvalidVersion:
        # A malformed version string must not abort model loading.
        logger.warning(f'Cannot parse transformers_version={cfg_ver!r} '
                       'from the model config; skipping the version check.')
        return

    if current_transformers_version < required_transformers_version:
        logger.warning(
            f'The current version of `transformers` is transformers=={current_transformers_version}, '  # noqa: E501
            f'which is lower than the required version transformers=={required_transformers_version}. '  # noqa: E501
            'Please upgrade to the required version.')
| 93 | |
| 94 | def get_vocab(self): |
| 95 | """Get vocab.""" |
no test coverage detected