MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / eval_hf

Function eval_hf

examples/summarize.py:414–525  ·  view source on GitHub ↗
(datapoint,
                eval_task='summarize',
                eval_ppl=False,
                add_special_tokens=True,
                min_input_length=0)

Source from the content-addressed store, hash-verified

412 return [], [], [], {}
413
414 def eval_hf(datapoint,
415 eval_task='summarize',
416 eval_ppl=False,
417 add_special_tokens=True,
418 min_input_length=0):
419 batch_size = len(datapoint[dataset_input_key])
420 if batch_size > 1:
421 logger.warning(
422 f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
423 )
424 batch_input_ids = _prepare_inputs(datapoint[dataset_input_key],
425 eval_task=eval_task,
426 add_special_tokens=add_special_tokens,
427 min_input_length=min_input_length)
428 batch_size = len(batch_input_ids)
429 if batch_size == 0:
430 return [], [], [], [[] for _ in range(batch_size)]
431 input_lengths = [x.size(0) for x in batch_input_ids]
432 # Left padding for HF
433 max_length = max(input_lengths)
434 paddings = [
435 torch.ones(max_length - l, dtype=torch.int32) * pad_id
436 for l in input_lengths
437 ]
438 batch_input_ids = [
439 torch.cat([pad, x]) for x, pad in zip(batch_input_ids, paddings)
440 ]
441 batch_input_ids = torch.stack(batch_input_ids)
442 batch_input_ids = batch_input_ids.cuda()
443
444 # specialization for HF
445 if early_stopping in [0, 1]:
446 local_early_stopping = bool(early_stopping)
447 else:
448 local_early_stopping = "never"
449
450 with torch.no_grad():
451 hf_config = {}
452 if num_beams == 1:
453 hf_config.update({
454 "top_k": top_k,
455 "top_p": top_p,
456 "do_sample": True,
457 })
458 else:
459 hf_config.update({
460 "num_beams": num_beams,
461 "early_stopping": local_early_stopping,
462 })
463
464 outputs = model.generate(batch_input_ids,
465 max_new_tokens=output_len,
466 num_return_sequences=num_sequences,
467 temperature=temperature,
468 eos_token_id=end_id,
469 pad_token_id=pad_id,
470 length_penalty=length_penalty,
471 output_scores=True,

Callers 1

main (function) · 0.85

Calls 14

ppl (function) · 0.90
max (function) · 0.85
batch_decode (method) · 0.80
sum (method) · 0.80
_prepare_inputs (function) · 0.70
model (function) · 0.50
warning (method) · 0.45
size (method) · 0.45
update (method) · 0.45
generate (method) · 0.45
empty (method) · 0.45
view (method) · 0.45

Tested by

no test coverage detected