(args)
| 314 | |
| 315 | |
| 316 | def main(args): |
| 317 | runtime_rank = tensorrt_llm.mpi_rank() |
| 318 | logger.set_level(args.log_level) |
| 319 | |
| 320 | # different handling if encoder-decoder models |
| 321 | is_enc_dec = {'encoder', 'decoder'}.issubset({ |
| 322 | name |
| 323 | for name in os.listdir(args.engine_dir) |
| 324 | if os.path.isdir(os.path.join(args.engine_dir, name)) |
| 325 | }) |
| 326 | if is_enc_dec: |
| 327 | logger.warning( |
| 328 | "This path is an encoder-decoder model. Using different handling.") |
| 329 | assert not args.use_py_session, "Encoder-decoder models don't have a unified python runtime, please use its own examples/models/core/enc_dec/run.py instead." |
| 330 | |
| 331 | model_name, model_version = read_model_name( |
| 332 | args.engine_dir if not is_enc_dec else os.path. |
| 333 | join(args.engine_dir, 'encoder')) |
| 334 | |
| 335 | if args.tokenizer_dir is None and model_name in DEFAULT_HF_MODEL_DIRS: |
| 336 | logger.warning( |
| 337 | "tokenizer_dir is not specified. Try to infer from model_name, but this may be incorrect." |
| 338 | ) |
| 339 | args.tokenizer_dir = DEFAULT_HF_MODEL_DIRS[model_name] |
| 340 | |
| 341 | tokenizer, pad_id, end_id = load_tokenizer( |
| 342 | tokenizer_dir=args.tokenizer_dir, |
| 343 | vocab_file=args.vocab_file, |
| 344 | model_name=model_name, |
| 345 | model_version=model_version, |
| 346 | tokenizer_type=args.tokenizer_type, |
| 347 | ) |
| 348 | |
| 349 | if args.end_id: |
| 350 | end_id = args.end_id |
| 351 | |
| 352 | prompt_template = None |
| 353 | if args.use_prompt_template and model_name in DEFAULT_PROMPT_TEMPLATES: |
| 354 | prompt_template = DEFAULT_PROMPT_TEMPLATES[model_name] |
| 355 | |
| 356 | batch_input_ids = parse_input(tokenizer=tokenizer, |
| 357 | input_text=args.input_text, |
| 358 | prompt_template=prompt_template, |
| 359 | input_file=args.input_file, |
| 360 | add_special_tokens=args.add_special_tokens, |
| 361 | max_input_length=args.max_input_length, |
| 362 | pad_id=pad_id, |
| 363 | num_prepend_vtokens=args.num_prepend_vtokens, |
| 364 | model_name=model_name, |
| 365 | model_version=model_version) |
| 366 | |
| 367 | stop_words_list = None |
| 368 | if args.stop_words: |
| 369 | stop_words_list = tensorrt_llm.runtime.decode_words_list( |
| 370 | args.stop_words, tokenizer) |
| 371 | if model_version == 'glm4': # add default stop token ids for GLM-4 |
| 372 | glm4_stop_ids = [[151329], [151336], [151338]] |
| 373 | if stop_words_list is None: |
no test coverage detected