(self, text, group=True, exclude_entities=["O"])
| 125 | self.model = load_model(Path(self.model_path) / model_name) |
| 126 | |
def predict(self, text, group=True, exclude_entities=("O",)):
    """Run token-classification inference on a single text.

    The text is tokenized once and fed to the ONNX session in
    ``self.model``. Each sub-word token is then paired with its arg-max
    label from ``self.labels`` and the softmax score at that label.

    Args:
        text: Input string to tag.
        group: When truthy, merge the per-token predictions with
            ``group_entities`` before filtering.
        exclude_entities: Labels whose predictions are dropped from the
            result. They are compared against the full label, before any
            prefix is stripped. Any container supporting ``in`` works.
            Defaults to ``("O",)``; a tuple avoids a shared mutable default.

    Returns:
        A list of prediction dicts. Ungrouped, each has the keys
        ``"index"``, ``"word"``, ``"entity"`` and ``"score"``. When grouped,
        the dicts are the ones produced by ``group_entities``. For every
        kept prediction, the text up to and including the first ``"-"`` in
        ``"entity"`` is removed (e.g. ``"B-PER"`` -> ``"PER"``).
    """
    # Tokenize exactly once and derive the token strings from the same ids
    # the model receives. Re-tokenizing decode(encode(text)) can drift from
    # those ids, e.g. through decode's whitespace cleanup.
    model_inputs = self.tokenizer(text, return_tensors="pt")
    # The ONNX runtime session consumes plain numpy arrays.
    inputs_onnx = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}
    tokens = self.tokenizer.convert_ids_to_tokens(
        inputs_onnx["input_ids"][0].tolist()
    )

    logits = self.model.run(None, inputs_onnx)[0]
    # NOTE(review): assumes `softmax` normalises over the label (last) axis.
    # scipy.special.softmax without axis= would normalise over the whole
    # array -- confirm which helper is in scope.
    probs = softmax(logits)
    label_ids = np.argmax(probs, axis=2)

    # [1:-1] drops the leading/trailing special tokens. This assumes the
    # tokenizer adds exactly one of each -- TODO confirm for every model used.
    scored = list(zip(tokens, probs[0], label_ids[0]))[1:-1]
    preds = [
        {
            "index": index,
            "word": token,
            "entity": self.labels[label_id],
            "score": scores[label_id],
        }
        for index, (token, scores, label_id) in enumerate(scored)
    ]

    if group:
        preds = group_entities(preds)

    out_preds = []
    for pred in preds:
        entity = pred["entity"]
        if entity in exclude_entities:
            continue
        # Strip the tag prefix ("B-PER" -> "PER"). Labels without a "-" (or
        # non-string labels) are kept unchanged rather than swallowing errors.
        if isinstance(entity, str) and "-" in entity:
            pred["entity"] = entity.split("-", 1)[1]
        out_preds.append(pred)

    return out_preds
| 169 | |
| 170 | def predict_batch(self, texts, group=True, exclude_entities=["O"]): |
| 171 | predictions = [] |
no test coverage detected