Repository: github.com/Tencent/NeuralNLP-NeuralClassifier

Method predict

predict.py:72–85

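Input texts should be JSON objects: each element of texts is a JSON string, which predict parses with json.loads before looking up its vocabulary ids.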
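A minimal sketch of preparing such inputs. The record fields (doc_label, doc_token, doc_keyword, doc_topic) are an assumption taken from the JSON-lines data format in the NeuralClassifier README, not something this page confirms; the records themselves are made up. The point is only that predict receives strings, not dicts.

```python
import json

# Hypothetical records. Field names follow the JSON-lines format described
# in the NeuralClassifier README; treat them as an assumption.
records = [
    {"doc_label": [], "doc_token": ["deep", "learning", "is", "fun"],
     "doc_keyword": [], "doc_topic": []},
    {"doc_label": [], "doc_token": ["stock", "prices", "fell"],
     "doc_keyword": [], "doc_topic": []},
]

# predict() calls json.loads on every element, so each input must be a JSON
# string (for example, one line of a .json data file).
texts = [json.dumps(record) for record in records]

# This mirrors the first step inside predict(): each string becomes a dict again.
parsed = [json.loads(text) for text in texts]
```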

Signature: predict(self, texts)


```python
# Relies on names imported at module level in predict.py:
# json, numpy (as np), torch, and ClassificationType.
def predict(self, texts):
    """
    input texts should be json objects
    """
    with torch.no_grad():
        # Parse each JSON string and convert it into vocabulary ids.
        batch_texts = [self.dataset._get_vocab_id_list(json.loads(text)) for text in texts]
        # Collate the per-text id lists into one model input batch.
        batch_texts = self.collate_fn(batch_texts)
        logits = self.model(batch_texts)
        # Single-label tasks: softmax over classes; multi-label tasks: independent sigmoids.
        if self.config.task_info.label_type != ClassificationType.MULTI_LABEL:
            probs = torch.softmax(logits, dim=1)
        else:
            probs = torch.sigmoid(logits)
        # Move to CPU and return a (num_texts, num_labels) NumPy array.
        probs = probs.cpu().tolist()
        return np.array(probs)
```
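The branch at the end of predict chooses the output activation: a softmax across classes (dim=1) when the task is not multi-label, so each row sums to 1, and an element-wise sigmoid when it is, so each label gets an independent probability. The sketch below reproduces that arithmetic on plain Python lists as a stand-in for torch.softmax and torch.sigmoid; softmax_rows, sigmoid_rows and to_probs are hypothetical helper names.

```python
import math

def softmax_rows(logits):
    """Row-wise softmax, analogous to torch.softmax(logits, dim=1)."""
    out = []
    for row in logits:
        m = max(row)  # subtract the row max so math.exp cannot overflow
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def sigmoid_rows(logits):
    """Element-wise sigmoid, analogous to torch.sigmoid(logits)."""
    return [[1.0 / (1.0 + math.exp(-x)) for x in row] for row in logits]

def to_probs(logits, multi_label):
    # Mirrors the branch in predict(): softmax unless the task is multi-label.
    return sigmoid_rows(logits) if multi_label else softmax_rows(logits)

logits = [[2.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
single = to_probs(logits, multi_label=False)
multi = to_probs(logits, multi_label=True)
```

Subtracting the row maximum leaves the softmax result unchanged; it only keeps the exponentials in range for large logits. Note that the multi-label rows are not expected to sum to 1.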

Callers 1

predict.py (file) · 0.80

Calls 1

_get_vocab_id_list (method) · 0.45
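predict hands each parsed object to self.dataset._get_vocab_id_list and then batches the results with self.collate_fn; neither implementation is shown on this page. As a rough, hypothetical illustration only, the sketch maps tokens to integer ids with an unknown-token fallback and right-pads the id lists to a common length. VOCAB, PAD_ID, UNK_ID, to_vocab_ids and pad_batch are invented names, and the repository's versions likely handle more fields and details than this.

```python
import json

# Hypothetical vocabulary; the real dataset builds its own from training data.
PAD_ID, UNK_ID = 0, 1
VOCAB = {"deep": 2, "learning": 3, "is": 4, "fun": 5}

def to_vocab_ids(json_obj):
    """Map the doc_token field of one parsed record to ids (unknown words -> UNK_ID)."""
    return [VOCAB.get(token, UNK_ID) for token in json_obj["doc_token"]]

def pad_batch(id_lists):
    """Right-pad every id list with PAD_ID to the length of the longest one."""
    max_len = max(len(ids) for ids in id_lists)
    return [ids + [PAD_ID] * (max_len - len(ids)) for ids in id_lists]

texts = [
    '{"doc_token": ["deep", "learning", "is", "fun"]}',
    '{"doc_token": ["deep", "nets"]}',
]
batch = pad_batch([to_vocab_ids(json.loads(text)) for text in texts])
# batch == [[2, 3, 4, 5], [2, 1, 0, 0]]
```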

Tested by

no test coverage detected
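predict returns one row of probabilities per input text. This page does not show how its caller turns those rows into labels, so the following is only a sketch under two assumptions: single-label tasks take the most probable column, and multi-label tasks keep every column above a hypothetical threshold of 0.5. LABELS and decode are illustrative names.

```python
# Hypothetical label names, one per output column.
LABELS = ["sports", "finance", "tech"]

def decode(probs, multi_label, threshold=0.5):
    """Turn per-text probability rows into lists of label names.

    probs: rows of floats, e.g. the array returned by predict()
    (iterating over a 2-D NumPy array yields its rows).
    """
    results = []
    for row in probs:
        row = list(row)
        if multi_label:
            # Independent sigmoid scores: keep every label above the threshold.
            results.append([LABELS[i] for i, p in enumerate(row) if p > threshold])
        else:
            # Softmax scores compete: keep only the most probable label.
            best = max(range(len(row)), key=row.__getitem__)
            results.append([LABELS[best]])
    return results

single = decode([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], multi_label=False)
multi = decode([[0.9, 0.2, 0.8], [0.4, 0.3, 0.1]], multi_label=True)
```

With this thresholding a multi-label row may yield no labels at all, as the second row of multi does; whether that is acceptable depends on the application.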