(self, texts=None, verbose=True)
| 635 | |
| 636 | ### Return Predictions ### |
def predict_batch(self, texts=None, verbose=True):
    """Run inference and return per-example label probabilities.

    Args:
        texts: Optional sequence of raw input texts. When truthy, they are
            converted into a dataloader via ``self.data.get_dl_from_texts``.
            Otherwise ``self.data.test_dl`` is used if it is truthy, with
            ``self.data.val_dl`` as the final fallback.
        verbose: When True, progress messages are written to ``self.logger``
            (a module-level logger is created if none is set yet).

    Returns:
        A list with one entry per example. Each entry is a list of
        ``(label, probability)`` pairs over ``self.data.labels``, sorted by
        probability in descending order. Probabilities are a sigmoid of the
        model output when ``self.multi_label`` is set, and a softmax over
        dim 1 otherwise.
    """
    if verbose and self.logger is None:
        self.logger = logging.getLogger(__name__)

    if texts:
        if verbose:
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...")
        dl = self.data.get_dl_from_texts(texts)
        if verbose:
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...DONE")
    elif self.data.test_dl:
        dl = self.data.test_dl
    else:
        dl = self.data.val_dl

    # Gather per-batch arrays and concatenate once after the loop; growing a
    # single array with np.concatenate per batch re-copies every earlier row
    # each time (quadratic in the number of batches).
    batch_probs = []

    self.model.eval()
    for step, batch in enumerate(dl):
        if verbose:
            self.logger.info(
                "---PROGRESS-STATUS---: Predicting batch {}/{}".format(
                    step + 1, len(dl)
                )
            )
        batch = tuple(t.to(self.device) for t in batch)

        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": None}

        # Only these model types are given the third batch tensor as
        # token_type_ids.
        if self.model_type in ["bert", "xlnet"]:
            inputs["token_type_ids"] = batch[2]

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs[0]
            # Independent per-label probabilities for multi-label tasks,
            # a distribution over labels otherwise.
            if self.multi_label:
                probs = logits.sigmoid()
            else:
                probs = logits.softmax(dim=1)

        batch_probs.append(probs.detach().cpu().numpy())

    # No batches -> None, which yields an empty DataFrame and an empty result.
    all_logits = np.concatenate(batch_probs, axis=0) if batch_probs else None

    result_df = pd.DataFrame(all_logits, columns=self.data.labels)
    results = result_df.to_dict(orient="records")

    if verbose:
        self.logger.info("---PROGRESS-STATUS---: Predicting batch...DONE")
    return [sorted(x.items(), key=lambda kv: kv[1], reverse=True) for x in results]
| 693 | |
| 694 | # Begin code for LR Finder |
nothing calls this directly
no test coverage detected