MCPcopy
hub / github.com/QData/TextAttack / evaluate

Method evaluate

textattack/trainer.py:883–927  ·  view source on GitHub ↗

Evaluate the model on the given evaluation dataset.

(self)

Source from the content-addressed store, hash-verified

881 self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
882
def evaluate(self):
    """Evaluate the model on the given evaluation dataset.

    Runs ``self.evaluate_step`` over every batch produced by
    ``self.get_eval_dataloader`` with gradients disabled, concatenates the
    per-batch predictions and targets, scores them, and logs the score.

    Returns:
        The Pearson correlation coefficient between predictions and targets
        when ``self.task_type == "regression"``; otherwise the classification
        accuracy as a fraction of correct predictions.

    Raises:
        ValueError: If no ``eval_dataset`` is set, or if the evaluation
            dataloader yields no batches (there is nothing to score).
    """
    if not self.eval_dataset:
        raise ValueError("No `eval_dataset` available for evaluation.")

    # Use the module logger (like the result line below) instead of the root
    # ``logging`` module, so this message follows the same configuration.
    logger.info("Evaluating model on evaluation dataset.")
    model = self.model_wrapper.model
    tokenizer = self.model_wrapper.tokenizer

    model.eval()
    all_preds = []
    all_targets = []

    # DataParallel splits each batch across the visible GPUs, so the
    # effective batch size is the per-device size times the GPU count.
    if isinstance(model, torch.nn.DataParallel):
        num_gpus = torch.cuda.device_count()
        eval_batch_size = self.training_args.per_device_eval_batch_size * num_gpus
    else:
        eval_batch_size = self.training_args.per_device_eval_batch_size

    eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)

    with torch.no_grad():
        for batch in eval_dataloader:
            preds, targets = self.evaluate_step(model, tokenizer, batch)
            all_preds.append(preds)
            all_targets.append(targets)

    # torch.cat on an empty list raises an opaque RuntimeError (and accuracy
    # would divide by zero), so fail early with a clear message instead.
    if not all_preds:
        raise ValueError(
            "Evaluation dataloader yielded no batches; nothing to evaluate."
        )

    preds = torch.cat(all_preds)
    targets = torch.cat(all_targets)

    if self.task_type == "regression":
        # pearsonr expects two 1-D sequences of equal length: flatten in case
        # predictions carry a trailing singleton dimension, and move to CPU so
        # the NumPy conversion works regardless of the tensors' device.
        pearson_correlation, _ = scipy.stats.pearsonr(
            preds.detach().reshape(-1).cpu().numpy(),
            targets.detach().reshape(-1).cpu().numpy(),
        )
        eval_score = pearson_correlation
    else:
        correct_predictions = (preds == targets).sum().item()
        eval_score = correct_predictions / len(targets)

    if self._metric_name == "accuracy":
        logger.info(f"Eval {self._metric_name}: {eval_score * 100:.2f}%")
    else:
        # Other metrics (e.g. a correlation) are not percentages: no "%".
        logger.info(f"Eval {self._metric_name}: {eval_score:.4f}")

    return eval_score
928
929 def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
930 if isinstance(self.training_args, CommandLineTrainingArgs):

Callers 1

train · Method · 0.95

Calls 2

get_eval_dataloader · Method · 0.95
evaluate_step · Method · 0.95

Tested by

no test coverage detected