MCPcopy
hub / github.com/InternLM/lmdeploy / HuggingFaceTokenizer

Class HuggingFaceTokenizer

lmdeploy/tokenizer.py:39–349  ·  view source on GitHub ↗

A wrapper of transformers' AutoTokenizer. Args: model_dir: the directory of the tokenizer model.

Source from the content-addressed store, hash-verified

37
38
39class HuggingFaceTokenizer:
40 """A wrapper of transformers' AutoTokenizer.
41
42 Args:
43 model_dir: the directory of the tokenizer model.
44 """
45
def __init__(self, model_dir: str, trust_remote_code: bool = False):
    """Load the tokenizer stored in ``model_dir`` and set up helper state.

    Args:
        model_dir: the directory of the tokenizer model.
        trust_remote_code: forwarded unchanged to the transformers version
            check and to ``AutoTokenizer.from_pretrained``.
    """
    # Run the version check first so a version-mismatch warning is logged
    # before the tokenizer itself is loaded.
    self._check_transformers_version(model_dir, trust_remote_code=trust_remote_code)
    from transformers import AutoTokenizer
    self.logger = get_logger('lmdeploy')
    self.model = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    self._prefix_space_tokens = None

    if self.model.eos_token_id is None:
        generation_config_file = osp.join(model_dir, 'generation_config.json')
        if osp.exists(generation_config_file):
            # JSON is UTF-8 by specification; do not rely on the platform
            # default encoding.
            with open(generation_config_file, encoding='utf-8') as f:
                cfg = json.load(f)
            # The file may exist without an 'eos_token_id' entry; in that
            # case leave the tokenizer untouched instead of raising KeyError.
            eos_token_id = cfg.get('eos_token_id')
            if eos_token_id is not None:
                self.model.eos_token_id = eos_token_id
        elif hasattr(self.model, 'eod_id'):  # Qwen remote
            self.model.eos_token_id = self.model.eod_id

    # for stop words
    self._vocab_size_with_added: int = None
    self._maybe_decode_bytes: bool = None
    # TODO maybe lack a constant.py
    self._indexes_tokens_deque = deque(maxlen=10)
    self.max_indexes_num = 5
    self.token2id = {}
69
def _check_transformers_version(self, model_dir: str, trust_remote_code: bool = False):
    """Warn if the installed `transformers` is older than the model's.

    The version recorded in the model config (``transformers_version``,
    looked up on the config itself and then on its ``llm_config``
    attribute) is compared with the installed package version. Nothing
    happens when the config records no version.

    Args:
        model_dir: the directory of the model.
        trust_remote_code: forwarded to ``get_model_arch``.
    """
    import transformers
    from packaging import version

    from lmdeploy.archs import get_model_arch

    log = get_logger('lmdeploy')

    installed = version.parse(transformers.__version__)
    model_cfg = get_model_arch(model_dir, trust_remote_code=trust_remote_code)[1]

    declared = getattr(model_cfg, 'transformers_version', None)
    if declared is None:
        # Fall back to the version recorded on the nested `llm_config`,
        # when the config has one.
        nested_cfg = getattr(model_cfg, 'llm_config', None)
        declared = getattr(nested_cfg, 'transformers_version', None) if nested_cfg else None
    if declared is None:
        return

    required = version.parse(declared)
    if installed < required:
        message = (
            f'The current version of `transformers` is transformers=={installed}, '  # noqa: E501
            f'which is lower than the required version transformers=={required}. '  # noqa: E501
            'Please upgrade to the required version.')
        log.warning(message)
93
def get_vocab(self):
    """Return the vocabulary reported by the wrapped tokenizer."""
    tokenizer = self.model
    return tokenizer.get_vocab()

Callers 3

__init__Method · 0.85

Calls

no outgoing calls

Tested by 2