Evaluate the model on the given evaluation dataset.
(self)
| 881 | self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size) |
| 882 | |
def evaluate(self):
    """Evaluate the model on the given evaluation dataset.

    Runs ``self.evaluate_step`` over every batch produced by
    ``self.get_eval_dataloader(self.eval_dataset, ...)`` with gradients
    disabled, concatenates the per-batch predictions and targets, and scores
    them.

    Returns:
        The evaluation score. When ``self.task_type == "regression"`` this is
        the Pearson correlation coefficient between predictions and targets;
        otherwise it is the fraction of predictions equal to their targets
        (accuracy as a value in ``[0, 1]``).

    Raises:
        ValueError: If no evaluation dataset is available, or if the
            evaluation dataloader yields no batches.
    """
    if not self.eval_dataset:
        raise ValueError("No `eval_dataset` available for evaluation.")

    logger.info("Evaluating model on evaluation dataset.")
    model = self.model_wrapper.model
    tokenizer = self.model_wrapper.tokenizer

    # Switch to eval mode for inference.
    model.eval()

    # DataParallel splits each batch across the visible GPUs, so scale the
    # per-device batch size by the GPU count.
    eval_batch_size = self.training_args.per_device_eval_batch_size
    if isinstance(model, torch.nn.DataParallel):
        eval_batch_size *= torch.cuda.device_count()

    eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)

    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch in eval_dataloader:
            preds, targets = self.evaluate_step(model, tokenizer, batch)
            all_preds.append(preds)
            all_targets.append(targets)

    # torch.cat raises an opaque error on an empty list; fail with a clear
    # message instead.
    if not all_preds:
        raise ValueError("Evaluation dataloader yielded no batches.")

    preds = torch.cat(all_preds)
    targets = torch.cat(all_targets)

    if self.task_type == "regression":
        # pearsonr returns (coefficient, p-value); only the coefficient is used.
        eval_score, _ = scipy.stats.pearsonr(preds, targets)
    else:
        correct_predictions = (preds == targets).sum().item()
        eval_score = correct_predictions / len(targets)

    if self._metric_name == "accuracy":
        logger.info(f"Eval {self._metric_name}: {eval_score * 100:.2f}%")
    else:
        # Other metrics (such as the Pearson correlation above) are not
        # percentages, so no "%" suffix is appended.
        logger.info(f"Eval {self._metric_name}: {eval_score:.4f}")

    return eval_score
| 928 | |
| 929 | def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size): |
| 930 | if isinstance(self.training_args, CommandLineTrainingArgs): |
no test coverage detected