Input texts should be JSON-encoded strings (one JSON object per string).
(self, texts)
| 70 | model.load_state_dict(checkpoint["state_dict"]) |
| 71 | |
def predict(self, texts):
    """Return predicted probabilities for a batch of examples.

    Args:
        texts: Iterable of examples. Each item is either a JSON-encoded
            string (``str``, ``bytes`` or ``bytearray``), which is decoded
            with ``json.loads``, or an already-decoded object (e.g. a
            ``dict``), which is used as-is.

    Returns:
        A ``numpy.ndarray`` built from the probability tensor via
        ``.tolist()`` (so the dtype is float64).
        - If the task's label type is not ``ClassificationType.MULTI_LABEL``,
          the values are ``torch.softmax`` over ``dim=1`` of the model logits.
        - Otherwise they are an element-wise ``torch.sigmoid`` of the logits.
    """
    with torch.no_grad():
        # Accept both serialized JSON text and pre-decoded objects. Before
        # this change, passing a dict made json.loads raise TypeError,
        # despite the docstring describing the inputs as "json objects".
        records = [
            json.loads(text) if isinstance(text, (str, bytes, bytearray)) else text
            for text in texts
        ]
        batch_texts = [self.dataset._get_vocab_id_list(record) for record in records]
        batch_texts = self.collate_fn(batch_texts)
        logits = self.model(batch_texts)
        if self.config.task_info.label_type != ClassificationType.MULTI_LABEL:
            # Single-label task: normalize along dim=1 (presumably the class
            # axis -- confirm against the model's output shape).
            probs = torch.softmax(logits, dim=1)
        else:
            # Multi-label task: each logit is squashed independently.
            probs = torch.sigmoid(logits)
        # Move to CPU and convert to Python floats; np.array then yields float64.
        probs = probs.cpu().tolist()
    return np.array(probs)
| 86 | |
| 87 | if __name__ == "__main__": |
| 88 | config = Config(config_file=sys.argv[1]) |
no test coverage detected