Run evaluation and return metrics. The calling script is responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init `compute_metrics` argument). You can also subclass and override this method to inject custom behavior.
(
    self,
    eval_dataset: Optional[Dataset] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
    **gen_kwargs
)
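
The generation options collected in `**gen_kwargs` are filled in from trainer defaults before evaluation runs. The following is a minimal standalone sketch of that fallback, not the trainer's own code: `resolve_gen_kwargs` is a hypothetical helper, and its `default_max_length` and `default_num_beams` parameters stand in for `self.args.generation_max_length` and `self.args.generation_num_beams`.

```python
from typing import Any, Dict, Optional


def resolve_gen_kwargs(
    gen_kwargs: Dict[str, Any],
    default_max_length: Optional[int],
    default_num_beams: Optional[int],
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the fallback logic used by `evaluate`."""
    # Work on a copy so the input dictionary is left untouched.
    resolved = dict(gen_kwargs)

    # Use the default max length only when the caller gave no length limit
    # at all (neither `max_length` nor `max_new_tokens`).
    if resolved.get("max_length") is None and resolved.get("max_new_tokens") is None:
        resolved["max_length"] = default_max_length

    # Use the default beam count when `num_beams` is missing or None.
    if resolved.get("num_beams") is None:
        resolved["num_beams"] = default_num_beams

    return resolved


if __name__ == "__main__":
    # No overrides: both defaults are applied.
    print(resolve_gen_kwargs({}, 128, 4))
    # `max_new_tokens` is set, so `max_length` is not injected.
    print(resolve_gen_kwargs({"max_new_tokens": 32}, 128, 4))
    # Explicit values win over the defaults.
    print(resolve_gen_kwargs({"max_length": 64, "num_beams": 1}, 128, 4))
```

Note that passing `max_new_tokens` suppresses the `max_length` default, which avoids handing two competing length limits to the generate method.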

class Seq2SeqTrainer(PrefixTrainer):
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script is responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
                "eval_bleu" if the prefix is `"eval"` (default).
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number, which comes from the training state.
        """

        # Work on a copy of the generation kwargs.
        gen_kwargs = gen_kwargs.copy()
        # Fall back to the configured max length only when no length limit was passed.
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        # Fall back to the configured beam count when `num_beams` is missing or None.
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
        )
        # Store the resolved kwargs on the trainer before delegating to the parent class.
        self._gen_kwargs = gen_kwargs

        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def predict(
        self,