MCPcopy
hub / github.com/appvision-ai/fast-bert / validate

Method validate

fast_bert/learner_cls copy.py:558–634  ·  view source on GitHub ↗
(self, quiet=False, loss_only=False, return_preds=False)

Source from the content-addressed store, hash-verified

556
557 ### Evaluate the model
558 def validate(self, quiet=False, loss_only=False, return_preds=False):
559 if quiet is False:
560 self.logger.info("Running evaluation")
561 self.logger.info(" Num examples = %d", len(self.data.val_dl.dataset))
562 self.logger.info(" Batch size = %d", self.data.val_batch_size)
563
564 all_logits = None
565 all_labels = None
566
567 eval_loss = 0
568 nb_eval_steps, nb_eval_examples = 0, 0
569
570 preds = None
571 out_label_ids = None
572
573 validation_scores = {metric["name"]: 0.0 for metric in self.metrics}
574
575 iterator = self.data.val_dl if quiet else progress_bar(self.data.val_dl)
576
577 for step, batch in enumerate(iterator):
578 self.model.eval()
579 batch = tuple(t.to(self.device) for t in batch)
580
581 with torch.no_grad():
582 inputs = {
583 "input_ids": batch[0],
584 "attention_mask": batch[1],
585 "labels": batch[3],
586 }
587
588 if self.model_type in ["bert", "xlnet"]:
589 inputs["token_type_ids"] = batch[2]
590
591 outputs = self.model(**inputs)
592 tmp_eval_loss, logits = outputs[:2]
593
594 eval_loss += tmp_eval_loss.mean().item()
595
596 nb_eval_steps += 1
597 nb_eval_examples += inputs["input_ids"].size(0)
598
599 if all_logits is None:
600 all_logits = logits
601 else:
602 all_logits = torch.cat((all_logits, logits), 0)
603
604 if all_labels is None:
605 all_labels = inputs["labels"]
606 else:
607 all_labels = torch.cat((all_labels, inputs["labels"]), 0)
608
609 if preds is None:
610 preds = logits.detach().cpu().numpy()
611 out_label_ids = inputs["labels"].detach().cpu().numpy()
612 else:
613 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
614 out_label_ids = np.append(
615 out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0

Callers 2

fit · Method · 0.95
lr_find · Method · 0.95

Calls 1

detach · Method · 0.45

Tested by

no test coverage detected