| 556 | |
| 557 | ### Evaluate the model |
| 558 | def validate(self, quiet=False, loss_only=False, return_preds=False): |
| 559 | if quiet is False: |
| 560 | self.logger.info("Running evaluation") |
| 561 | self.logger.info(" Num examples = %d", len(self.data.val_dl.dataset)) |
| 562 | self.logger.info(" Batch size = %d", self.data.val_batch_size) |
| 563 | |
| 564 | all_logits = None |
| 565 | all_labels = None |
| 566 | |
| 567 | eval_loss = 0 |
| 568 | nb_eval_steps, nb_eval_examples = 0, 0 |
| 569 | |
| 570 | preds = None |
| 571 | out_label_ids = None |
| 572 | |
| 573 | validation_scores = {metric["name"]: 0.0 for metric in self.metrics} |
| 574 | |
| 575 | iterator = self.data.val_dl if quiet else progress_bar(self.data.val_dl) |
| 576 | |
| 577 | for step, batch in enumerate(iterator): |
| 578 | self.model.eval() |
| 579 | batch = tuple(t.to(self.device) for t in batch) |
| 580 | |
| 581 | with torch.no_grad(): |
| 582 | inputs = { |
| 583 | "input_ids": batch[0], |
| 584 | "attention_mask": batch[1], |
| 585 | "labels": batch[3], |
| 586 | } |
| 587 | |
| 588 | if self.model_type in ["bert", "xlnet"]: |
| 589 | inputs["token_type_ids"] = batch[2] |
| 590 | |
| 591 | outputs = self.model(**inputs) |
| 592 | tmp_eval_loss, logits = outputs[:2] |
| 593 | |
| 594 | eval_loss += tmp_eval_loss.mean().item() |
| 595 | |
| 596 | nb_eval_steps += 1 |
| 597 | nb_eval_examples += inputs["input_ids"].size(0) |
| 598 | |
| 599 | if all_logits is None: |
| 600 | all_logits = logits |
| 601 | else: |
| 602 | all_logits = torch.cat((all_logits, logits), 0) |
| 603 | |
| 604 | if all_labels is None: |
| 605 | all_labels = inputs["labels"] |
| 606 | else: |
| 607 | all_labels = torch.cat((all_labels, inputs["labels"]), 0) |
| 608 | |
| 609 | if preds is None: |
| 610 | preds = logits.detach().cpu().numpy() |
| 611 | out_label_ids = inputs["labels"].detach().cpu().numpy() |
| 612 | else: |
| 613 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) |
| 614 | out_label_ids = np.append( |
| 615 | out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0 |