MCPcopy
hub / github.com/yangjianxin1/Firefly / load_tokenizer

Function load_tokenizer

train.py:187–218  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

185
186
def load_tokenizer(args):
    """Load the tokenizer for ``args.model_name_or_path`` and patch its special tokens.

    The tokenizer is created with ``AutoTokenizer.from_pretrained`` (remote code
    trusted). The slow implementation is requested when the model config reports
    a ``model_type`` of ``'llama'`` or ``'internlm2'``; otherwise the fast one is
    requested. Model-specific special tokens are then registered based on
    substrings of the lower-cased model path, and pad/bos/eos ids are filled in
    where needed.

    Args:
        args: Object exposing ``model_name_or_path`` (a string; it is
            lower-cased for the model-family checks).

    Returns:
        The configured tokenizer instance.

    Raises:
        AssertionError: If ``pad_token_id`` or ``eos_token_id`` is still ``None``
            after all adjustments.
    """
    model_path = args.model_name_or_path
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # Load the tokenizer. llama does not support the fast tokenizer
    # (internlm2 is treated the same way here).
    needs_slow_tokenizer = config.model_type in ('llama', 'internlm2')
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=not needs_slow_tokenizer,
    )

    # The base and chat versions of some models ship different tokenizers,
    # so the special tokens are registered explicitly per model family.
    lowered_path = model_path.lower()
    if 'internlm2' in lowered_path:
        # Register the two tokens under fixed ids in the tokenizer's private
        # added-token tables before declaring them as special tokens.
        fixed_token_ids = {'<|im_start|>': 92543, '<|im_end|>': 92542}
        tokenizer._added_tokens_encoder.update(fixed_token_ids)
        tokenizer._added_tokens_decoder.update(
            {token_id: AddedToken(token) for token, token_id in fixed_token_ids.items()}
        )
        tokenizer.add_special_tokens({'additional_special_tokens': list(fixed_token_ids)})
    elif 'orion' in lowered_path:
        tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
    elif 'gemma' in lowered_path:
        tokenizer.add_special_tokens({'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']})

    # QWenTokenizer is matched by class name; its pad/bos/eos ids are all
    # pointed at the tokenizer's eod_id.
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        eod_id = tokenizer.eod_id
        tokenizer.pad_token_id = tokenizer.bos_token_id = tokenizer.eos_token_id = eod_id

    # Fall back to the EOS token when no pad token is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
    assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
    logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
    return tokenizer
219
220
221def load_unsloth_model(args, training_args):

Callers 1

init_componentsFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected