MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / main

Function main

examples/run.py:316–705  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

314
315
316def main(args):
317 runtime_rank = tensorrt_llm.mpi_rank()
318 logger.set_level(args.log_level)
319
320 # different handling if encoder-decoder models
321 is_enc_dec = {'encoder', 'decoder'}.issubset({
322 name
323 for name in os.listdir(args.engine_dir)
324 if os.path.isdir(os.path.join(args.engine_dir, name))
325 })
326 if is_enc_dec:
327 logger.warning(
328 "This path is an encoder-decoder model. Using different handling.")
329 assert not args.use_py_session, "Encoder-decoder models don't have a unified python runtime, please use its own examples/models/core/enc_dec/run.py instead."
330
331 model_name, model_version = read_model_name(
332 args.engine_dir if not is_enc_dec else os.path.
333 join(args.engine_dir, 'encoder'))
334
335 if args.tokenizer_dir is None and model_name in DEFAULT_HF_MODEL_DIRS:
336 logger.warning(
337 "tokenizer_dir is not specified. Try to infer from model_name, but this may be incorrect."
338 )
339 args.tokenizer_dir = DEFAULT_HF_MODEL_DIRS[model_name]
340
341 tokenizer, pad_id, end_id = load_tokenizer(
342 tokenizer_dir=args.tokenizer_dir,
343 vocab_file=args.vocab_file,
344 model_name=model_name,
345 model_version=model_version,
346 tokenizer_type=args.tokenizer_type,
347 )
348
349 if args.end_id:
350 end_id = args.end_id
351
352 prompt_template = None
353 if args.use_prompt_template and model_name in DEFAULT_PROMPT_TEMPLATES:
354 prompt_template = DEFAULT_PROMPT_TEMPLATES[model_name]
355
356 batch_input_ids = parse_input(tokenizer=tokenizer,
357 input_text=args.input_text,
358 prompt_template=prompt_template,
359 input_file=args.input_file,
360 add_special_tokens=args.add_special_tokens,
361 max_input_length=args.max_input_length,
362 pad_id=pad_id,
363 num_prepend_vtokens=args.num_prepend_vtokens,
364 model_name=model_name,
365 model_version=model_version)
366
367 stop_words_list = None
368 if args.stop_words:
369 stop_words_list = tensorrt_llm.runtime.decode_words_list(
370 args.stop_words, tokenizer)
371 if model_version == 'glm4': # add default stop token ids for GLM-4
372 glm4_stop_ids = [[151329], [151336], [151338]]
373 if stop_words_list is None:

Callers 1

run.py · File · 0.70

Calls 15

read_model_name · Function · 0.90
load_tokenizer · Function · 0.90
prepare_enc_dec_inputs · Function · 0.90
get_beam_width_array · Function · 0.90
run_dtm_ngram · Function · 0.90
throttle_generator · Function · 0.90
max · Function · 0.85
print_output · Function · 0.85
set_level · Method · 0.80
synchronize · Method · 0.80

Tested by

no test coverage detected