MCPcopy
hub / github.com/appvision-ai/fast-bert / predict_batch

Method predict_batch

fast_bert/learner_cls copy.py:637–692  ·  view source on GitHub ↗
(self, texts=None, verbose=True)

Source from the content-addressed store, hash-verified

635
636 ### Return Predictions ###
def predict_batch(self, texts=None, verbose=True):
    """Run inference and return each example's label scores, highest first.

    Args:
        texts: Optional collection of raw input texts. If truthy, it is turned
            into a DataLoader via ``self.data.get_dl_from_texts``. If ``None``,
            ``self.data.test_dl`` is used when it is truthy, otherwise
            ``self.data.val_dl``. An explicitly empty collection returns
            ``[]`` without touching any loader or the model.
        verbose: If True, log "---PROGRESS-STATUS---" messages through
            ``self.logger``. The logger is created with ``logging.getLogger``
            when it is None.

    Returns:
        One entry per predicted row, in DataLoader order. Each entry is a list
        of ``(label, score)`` tuples keyed by ``self.data.labels`` and sorted
        by score, descending. Scores are ``sigmoid`` outputs when
        ``self.multi_label`` is truthy, otherwise ``softmax`` over dim 1.
    """
    if verbose and self.logger is None:
        self.logger = logging.getLogger(__name__)

    if texts is not None and not texts:
        # Nothing to score. Without this guard the falsy value would fall
        # through to the test/validation loaders below and silently return
        # predictions for a different dataset.
        return []

    if texts:
        if verbose:
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...")
        dl = self.data.get_dl_from_texts(texts)
        if verbose:
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...DONE")
    elif self.data.test_dl:
        dl = self.data.test_dl
    else:
        dl = self.data.val_dl

    # Gather per-batch arrays and join them once after the loop. Calling
    # np.concatenate on every batch would re-copy all previous rows each time.
    batch_scores = []

    self.model.eval()
    with torch.no_grad():
        for step, batch in enumerate(dl):
            if verbose:
                self.logger.info(
                    "---PROGRESS-STATUS---: Predicting batch {}/{}".format(
                        step + 1, len(dl)
                    )
                )
            batch = tuple(t.to(self.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": None,
            }
            # token_type_ids come from the third batch tensor, and are passed
            # only for these model types.
            if self.model_type in ("bert", "xlnet"):
                inputs["token_type_ids"] = batch[2]

            logits = self.model(**inputs)[0]
            if self.multi_label:
                scores = logits.sigmoid()
            else:
                scores = logits.softmax(dim=1)
            batch_scores.append(scores.detach().cpu().numpy())

    # With no batches, pass None so pandas builds an empty frame, giving [].
    all_logits = np.concatenate(batch_scores, axis=0) if batch_scores else None
    result_df = pd.DataFrame(all_logits, columns=self.data.labels)
    results = result_df.to_dict(orient="records")

    if verbose:
        self.logger.info("---PROGRESS-STATUS---: Predicting batch...DONE")
    return [sorted(row.items(), key=lambda kv: kv[1], reverse=True) for row in results]
693
694 # Begin code for LR Finder

Callers

nothing calls this directly

Calls 2

get_dl_from_texts (method) · 0.45
detach (method) · 0.45

Tested by

no test coverage detected