(datapoint,
eval_task='summarize',
eval_ppl=False,
add_special_tokens=True,
min_input_length=0)
| 412 | return [], [], [], {} |
| 413 | |
| 414 | def eval_hf(datapoint, |
| 415 | eval_task='summarize', |
| 416 | eval_ppl=False, |
| 417 | add_special_tokens=True, |
| 418 | min_input_length=0): |
| 419 | batch_size = len(datapoint[dataset_input_key]) |
| 420 | if batch_size > 1: |
| 421 | logger.warning( |
| 422 | f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}" |
| 423 | ) |
| 424 | batch_input_ids = _prepare_inputs(datapoint[dataset_input_key], |
| 425 | eval_task=eval_task, |
| 426 | add_special_tokens=add_special_tokens, |
| 427 | min_input_length=min_input_length) |
| 428 | batch_size = len(batch_input_ids) |
| 429 | if batch_size == 0: |
| 430 | return [], [], [], [[] for _ in range(batch_size)] |
| 431 | input_lengths = [x.size(0) for x in batch_input_ids] |
| 432 | # Left padding for HF |
| 433 | max_length = max(input_lengths) |
| 434 | paddings = [ |
| 435 | torch.ones(max_length - l, dtype=torch.int32) * pad_id |
| 436 | for l in input_lengths |
| 437 | ] |
| 438 | batch_input_ids = [ |
| 439 | torch.cat([pad, x]) for x, pad in zip(batch_input_ids, paddings) |
| 440 | ] |
| 441 | batch_input_ids = torch.stack(batch_input_ids) |
| 442 | batch_input_ids = batch_input_ids.cuda() |
| 443 | |
| 444 | # specialization for HF |
| 445 | if early_stopping in [0, 1]: |
| 446 | local_early_stopping = bool(early_stopping) |
| 447 | else: |
| 448 | local_early_stopping = "never" |
| 449 | |
| 450 | with torch.no_grad(): |
| 451 | hf_config = {} |
| 452 | if num_beams == 1: |
| 453 | hf_config.update({ |
| 454 | "top_k": top_k, |
| 455 | "top_p": top_p, |
| 456 | "do_sample": True, |
| 457 | }) |
| 458 | else: |
| 459 | hf_config.update({ |
| 460 | "num_beams": num_beams, |
| 461 | "early_stopping": local_early_stopping, |
| 462 | }) |
| 463 | |
| 464 | outputs = model.generate(batch_input_ids, |
| 465 | max_new_tokens=output_len, |
| 466 | num_return_sequences=num_sequences, |
| 467 | temperature=temperature, |
| 468 | eos_token_id=end_id, |
| 469 | pad_token_id=pad_id, |
| 470 | length_penalty=length_penalty, |
| 471 | output_scores=True, |
no test coverage detected