A wrapper around the `transformers` AutoTokenizer. Args: model_dir: the directory containing the tokenizer model; trust_remote_code: passed through to `AutoTokenizer.from_pretrained`.
| 37 | |
| 38 | |
class HuggingFaceTokenizer:
    """A wrapper of transformers' AutoTokenizer.

    Args:
        model_dir: the directory of the tokenizer model.
        trust_remote_code: forwarded to ``AutoTokenizer.from_pretrained`` and
            to the model-arch lookup used by the transformers version check.
    """
| 45 | |
def __init__(self, model_dir: str, trust_remote_code: bool = False):
    """Load the tokenizer from ``model_dir`` and initialise decoding state.

    Args:
        model_dir: the directory of the tokenizer model.
        trust_remote_code: forwarded to the transformers version check and to
            ``AutoTokenizer.from_pretrained``.
    """
    # Warn early if the installed transformers is older than the model expects.
    self._check_transformers_version(model_dir, trust_remote_code=trust_remote_code)
    from transformers import AutoTokenizer
    self.logger = get_logger('lmdeploy')
    self.model = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    self._prefix_space_tokens = None

    # Some tokenizers ship without an EOS id. Recover it from the model's
    # generation config first, then from Qwen's remote-code ``eod_id``.
    if self.model.eos_token_id is None:
        eos_token_id = None
        generation_config_file = osp.join(model_dir, 'generation_config.json')
        if osp.exists(generation_config_file):
            with open(generation_config_file, encoding='utf-8') as f:
                cfg = json.load(f)
            # ``.get`` so a generation config lacking ``eos_token_id`` does not
            # abort construction with a KeyError, and still allows the
            # ``eod_id`` fallback below.
            # NOTE(review): generation configs may store a list of ids here —
            # confirm the tokenizer's ``eos_token_id`` setter accepts that.
            eos_token_id = cfg.get('eos_token_id')
        if eos_token_id is None and hasattr(self.model, 'eod_id'):  # Qwen remote
            eos_token_id = self.model.eod_id
        if eos_token_id is not None:
            self.model.eos_token_id = eos_token_id

    # for stop words
    self._vocab_size_with_added: int = None
    self._maybe_decode_bytes: bool = None
    # TODO maybe lack a constant.py
    self._indexes_tokens_deque = deque(maxlen=10)
    self.max_indexes_num = 5
    self.token2id = {}
| 69 | |
def _check_transformers_version(self, model_dir: str, trust_remote_code: bool = False):
    """Warn if the installed ``transformers`` is older than the model config records.

    The required version is read from the config's ``transformers_version``
    attribute, falling back to ``llm_config.transformers_version``. The check
    is advisory: it only logs warnings and never raises because of a missing
    or unparseable recorded version.

    Args:
        model_dir: the directory of the model.
        trust_remote_code: forwarded to ``get_model_arch``.
    """
    import transformers
    from packaging import version

    from lmdeploy.archs import get_model_arch

    logger = get_logger('lmdeploy')

    cfg = get_model_arch(model_dir, trust_remote_code=trust_remote_code)[1]
    cfg_ver = getattr(cfg, 'transformers_version', None)
    if cfg_ver is None:
        # Composite (e.g. multimodal) configs may record it on the LLM sub-config.
        llm_config = getattr(cfg, 'llm_config', None)
        if llm_config is not None:
            cfg_ver = getattr(llm_config, 'transformers_version', None)
    if cfg_ver is None:
        return
    try:
        required_transformers_version = version.parse(cfg_ver)
    except (version.InvalidVersion, TypeError):
        # An odd value in the model config must not break tokenizer loading.
        logger.warning(f'Skip the transformers version check: cannot parse {cfg_ver!r} from the model config.')
        return
    current_transformers_version = version.parse(transformers.__version__)
    if current_transformers_version < required_transformers_version:
        logger.warning(
            f'The current version of `transformers` is transformers=={current_transformers_version}, '  # noqa: E501
            f'which is lower than the required version transformers=={required_transformers_version}. '  # noqa: E501
            'Please upgrade to the required version.')
| 93 | |
def get_vocab(self):
    """Return the vocabulary reported by the wrapped tokenizer.

    Thin delegation to ``self.model.get_vocab()``.
    """
    hf_tokenizer = self.model
    vocab = hf_tokenizer.get_vocab()
    return vocab
no outgoing calls