MCPcopy
hub / github.com/appvision-ai/fast-bert / predict

Method predict

fast_bert/prediction_ner.py:127–168  ·  view source on GitHub ↗
(self, text, group=True, exclude_entities=["O"])

Source from the content-addressed store, hash-verified

125 self.model = load_model(Path(self.model_path) / model_name)
126
def predict(self, text, group=True, exclude_entities=("O",)):
    """Run token-classification (NER) inference on a single string.

    Args:
        text: Raw input string.
        group: When true, per-token predictions are merged with
            ``group_entities`` before filtering.
        exclude_entities: Entity labels to drop from the result. They are
            compared against each prediction's ``entity`` value before the
            "-" prefix is stripped. The default is an immutable tuple, so it
            can never be mutated and shared across calls.

    Returns:
        A list of prediction dicts. Before grouping, each dict has the keys
        ``index``, ``word``, ``entity`` and ``score``. For string labels that
        contain "-", only the part after the first "-" is kept
        (e.g. ``"B-PER"`` -> ``"PER"``).
    """
    model_inputs = self.tokenizer(text, return_tensors="pt")

    # Take the tokens from the exact ids that are fed to the model, so tokens
    # and per-position predictions line up one-to-one. The previous approach
    # re-tokenized decode(encode(text)): it tokenized twice and could drift
    # from these ids, and zip() would then silently truncate.
    tokens = self.tokenizer.convert_ids_to_tokens(
        model_inputs["input_ids"][0].tolist()
    )

    # The ONNX runtime session is fed numpy arrays rather than torch tensors.
    inputs_onnx = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}
    outputs = softmax(self.model.run(None, inputs_onnx)[0])
    predictions = np.argmax(outputs, axis=2)

    # [1:-1] drops the first and last positions. These are presumably the
    # special tokens the tokenizer adds (e.g. [CLS]/[SEP]); TODO confirm for
    # every tokenizer this class is used with.
    token_preds = list(zip(tokens, outputs[0], predictions[0]))[1:-1]

    preds = [
        {
            "index": index,
            "word": token,
            "entity": self.labels[prediction],
            "score": output[prediction],
        }
        for index, (token, output, prediction) in enumerate(token_preds)
    ]

    if group:
        preds = group_entities(preds)

    out_preds = []
    for pred in preds:
        entity = pred["entity"]
        if entity in exclude_entities:
            continue
        # Strip the tagging-scheme prefix ("B-PER" -> "PER"). Labels without
        # "-", and non-string labels, are passed through unchanged. This is the
        # same outcome the old try/except produced, without hiding other errors.
        if isinstance(entity, str) and "-" in entity:
            pred["entity"] = entity.split("-")[1]
        out_preds.append(pred)

    return out_preds
169
170 def predict_batch(self, texts, group=True, exclude_entities=["O"]):
171 predictions = []

Callers 1

predict_batch (Method) · 0.95

Calls 3

group_entities (Function) · 0.85
softmax (Function) · 0.70
detach (Method) · 0.45

Tested by

no test coverage detected