github.com/NVIDIA/TensorRT-LLM / main

Function main(args)

examples/eval_long_context.py:149–326

def main(args):
    # model_name = "yarn-mistral"
    runtime_rank = tensorrt_llm.mpi_rank()
    logger.set_level(args.log_level)

    print(json.dumps(vars(args), indent=4))
    data_name = args.task

    # Model
    max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name]

    model_name, model_version = read_model_name(args.engine_dir)
    if args.tokenizer_dir is None:
        logger.warning(
            "tokenizer_dir is not specified. Try to infer from model_name, but this may be incorrect."
        )
        args.tokenizer_dir = DEFAULT_HF_MODEL_DIRS[model_name]

    tokenizer, pad_id, end_id = load_tokenizer(
        tokenizer_dir=args.tokenizer_dir,
        vocab_file=args.vocab_file,
        model_name=model_name,
        model_version=model_version,
        tokenizer_type=args.tokenizer_type,
    )

    if not PYTHON_BINDINGS and not args.use_py_session:
        logger.warning(
            "Python bindings of C++ session is unavailable, fallback to Python session."
        )
        args.use_py_session = True
    if args.debug_mode and not args.use_py_session:
        logger.warning(
            "Debug mode is not supported in C++ session for now, fallback to Python session."
        )
        args.use_py_session = True
    runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
    runner_kwargs = dict(
        engine_dir=args.engine_dir,
        lora_dir=args.lora_dir,
        rank=runtime_rank,
        debug_mode=args.debug_mode,
        lora_ckpt_source=args.lora_ckpt_source,
        gpu_weights_percent=args.gpu_weights_percent,
    )
    if args.medusa_choices is not None:
        args.medusa_choices = ast.literal_eval(args.medusa_choices)
        assert args.temperature == 1.0, "Medusa should use temperature == 1.0"
        assert args.num_beams == 1, "Medusa should use num_beams == 1"
        runner_kwargs.update(medusa_choices=args.medusa_choices)
    if not args.use_py_session:
        runner_kwargs.update(
            max_batch_size=args.batch_size,
            max_input_len=args.max_input_length,
            max_output_len=max_tokens,
            max_beam_width=args.num_beams,
            max_attention_window_size=args.max_attention_window_size,
            sink_token_length=args.sink_token_length,
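
The session-selection and runner-configuration steps in the excerpt can be summarized with a small, dependency-free sketch. It is an illustration under stated assumptions, not the example's actual code: `Args` and `build_runner_config` are hypothetical names, the `ModelRunner`/`ModelRunnerCpp` classes are represented by plain strings, only a subset of the keyword arguments is modeled, and the `assert` checks are swapped for `ValueError` so they still fire under `python -O`.

```python
import ast
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class Args:
    """Hypothetical stand-in for the parsed CLI arguments (subset only)."""
    engine_dir: str = "engine"
    use_py_session: bool = False
    debug_mode: bool = False
    medusa_choices: Optional[str] = None
    temperature: float = 1.0
    num_beams: int = 1
    batch_size: int = 1
    max_input_length: int = 1024


def build_runner_config(args: Args, python_bindings: bool,
                        max_tokens: int) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: return (runner class name, runner kwargs)."""
    use_py = args.use_py_session
    # Without the C++ bindings, only the Python session can run.
    if not python_bindings and not use_py:
        use_py = True
    # Debug mode is only supported by the Python session.
    if args.debug_mode and not use_py:
        use_py = True
    runner = "ModelRunner" if use_py else "ModelRunnerCpp"

    kwargs: Dict[str, Any] = {"engine_dir": args.engine_dir,
                              "debug_mode": args.debug_mode}
    if args.medusa_choices is not None:
        # Medusa choices arrive as a Python-literal string, e.g. "[[0], [0, 0]]".
        choices = ast.literal_eval(args.medusa_choices)
        if args.temperature != 1.0:
            raise ValueError("Medusa should use temperature == 1.0")
        if args.num_beams != 1:
            raise ValueError("Medusa should use num_beams == 1")
        kwargs["medusa_choices"] = choices
    # Capacity limits are passed only when the C++ runner is used.
    if not use_py:
        kwargs.update(max_batch_size=args.batch_size,
                      max_input_len=args.max_input_length,
                      max_output_len=max_tokens,
                      max_beam_width=args.num_beams)
    return runner, kwargs


if __name__ == "__main__":
    # Debug mode forces the Python session, so no C++-only limits are set.
    print(build_runner_config(Args(debug_mode=True), True, 128))
    # prints ('ModelRunner', {'engine_dir': 'engine', 'debug_mode': True})
```

The two fallbacks are checked in the same order as in `main`, so both missing C++ bindings and debug mode end in the Python session, and the C++-only capacity limits (`max_batch_size`, `max_output_len`, and so on) are added only when the C++ runner is selected.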

Callers: 1

Calls: 15

read_model_name · Function · 0.90
load_tokenizer · Function · 0.90
load_data · Function · 0.90
create_prompt · Function · 0.90
get_answer · Function · 0.90
dump_jsonl · Function · 0.90
compute_scores · Function · 0.90
set_level · Method · 0.80
synchronize · Method · 0.80
batch_decode · Method · 0.80
parse_input · Function · 0.70
warning · Method · 0.45
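
The call list hints at what the part of `main` not shown above does with the runner: load the task data, build prompts, generate, decode, write predictions, and score them. The sketch below wires those steps together in the order their names suggest; that order is an assumption. Every helper is a hypothetical stand-in that only borrows a name from the list, with invented signatures and toy bodies, exact match stands in for the task-specific metric, and a plain `generate` callable replaces the tokenization (`parse_input`), batched runner call, `synchronize`, and `batch_decode` steps of the real function.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List


# Hypothetical stand-ins named after entries in the call list; their
# signatures and bodies are invented for illustration only.
def load_data(task: str) -> List[Dict[str, str]]:
    return [{"context": "The key is 42.", "input": "What is the key?",
             "answer": "42"},
            {"context": "The sky is blue.", "input": "What color is the sky?",
             "answer": "blue"}]


def create_prompt(example: Dict[str, str]) -> str:
    return f"{example['context']}\n\nQuestion: {example['input']}\nAnswer:"


def get_answer(example: Dict[str, str]) -> str:
    return example["answer"]


def dump_jsonl(records: List[Dict[str, str]], path: str) -> None:
    # One JSON object per line.
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")


def compute_scores(records: List[Dict[str, str]]) -> float:
    # Exact-match accuracy; a real long-context task may use another metric.
    hits = sum(r["prediction"].strip() == r["ground_truth"] for r in records)
    return hits / len(records)


def evaluate(task: str, generate: Callable[[str], str], out_path: str) -> float:
    """Hypothetical driver: prompt -> generate -> dump -> score."""
    records = []
    for example in load_data(task):
        prompt = create_prompt(example)
        records.append({"prediction": generate(prompt),
                        "ground_truth": get_answer(example)})
    dump_jsonl(records, out_path)
    return compute_scores(records)


if __name__ == "__main__":
    # Toy "model" that always answers "42": one of two examples is correct.
    path = os.path.join(tempfile.mkdtemp(), "preds.jsonl")
    print(evaluate("toy_task", lambda prompt: " 42", path))  # prints 0.5
```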

Tested by: no test coverage detected.