MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / main

Function main

examples/mmlu.py:358–475  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

356
357
358def main():
359 args = parse_args()
360 if args.tokenizer_dir is None:
361 args.tokenizer_dir = args.hf_model_dir
362 random.seed(RAND_SEED)
363 np.random.seed(RAND_SEED)
364 runtime_rank = tensorrt_llm.mpi_rank()
365
366 os.path.dirname(os.path.abspath(__file__))
367 data_fullpath = os.path.join(args.data_dir, "test")
368
369 subjects = sorted([
370 f.split("_test.csv")[0] for f in os.listdir(data_fullpath)
371 if "_test.csv" in f
372 ])
373
374 all_cors = []
375 subcat_cors = {
376 subcat: []
377 for subcat_lists in get_subcategories().values()
378 for subcat in subcat_lists
379 }
380 cat_cors = {cat: [] for cat in get_categories()}
381
382 # different handling if encoder-decoder models
383 is_enc_dec = read_is_enc_dec(
384 args.engine_dir if not args.test_hf else args.hf_model_dir,
385 args.test_hf)
386
387 model_name, model_version = read_model_name(
388 (args.engine_dir if not is_enc_dec else os.path.join(
389 args.engine_dir, 'encoder'))
390 if not args.test_hf else args.hf_model_dir, args.test_hf)
391
392 tokenizer, pad_id, end_id = load_tokenizer(
393 tokenizer_dir=args.tokenizer_dir,
394 vocab_file=args.vocab_file,
395 model_name=model_name,
396 model_version=model_version,
397 )
398
399 if args.test_trt_llm:
400 assert not args.test_hf, "Cannot test both TRT-LLM and HF"
401 runner_cls = ModelRunner if not PYTHON_BINDINGS else ModelRunnerCpp
402 runner_kwargs = {}
403 if PYTHON_BINDINGS:
404 runner_kwargs.update(max_beam_width=1)
405 runner_kwargs.update(
406 is_enc_dec=is_enc_dec,
407 max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
408 kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
409 kv_cache_free_gpu_memory_fraction=args.
410 kv_cache_free_gpu_memory_fraction,
411 cross_kv_cache_fraction=args.cross_kv_cache_fraction
412 if is_enc_dec else None,
413 enable_chunked_context=args.enable_chunked_context,
414 multi_block_mode=args.multi_block_mode)
415 model = runner_cls.from_dir(engine_dir=args.engine_dir,

Callers 1

mmlu.py — File · 0.70

Calls 15

read_is_enc_dec — Function · 0.90
read_model_name — Function · 0.90
load_tokenizer — Function · 0.90
get_subcategories — Function · 0.85
get_categories — Function · 0.85
Pipeline — Class · 0.85
evaluate — Function · 0.85
mean — Method · 0.80
parse_args — Function · 0.70
split — Method · 0.45
values — Method · 0.45
update — Method · 0.45

Tested by

no test coverage detected