MCPcopy
hub / github.com/appvision-ai/fast-bert / validate

Method validate

fast_bert/learner_qa.py:195–274  ·  view source on GitHub ↗
(self, prefix="SQUAD", n_best_size=20, max_answer_length=30, verbose_logging=False, null_score_diff_threshold=0.0)

Source from the content-addressed store, hash-verified

193
194 ### Evaluate the model
195 def validate(self, prefix="SQUAD", n_best_size=20, max_answer_length=30, verbose_logging=False, null_score_diff_threshold=0.0):
196 self.logger.info("Running evaluation")
197
198 self.logger.info(" Num examples = %d", len(self.data.val_dl.dataset))
199 self.logger.info(" Batch size = %d", self.data.val_batch_size)
200
201 all_logits = None
202 all_labels = None
203
204
205 eval_loss, eval_accuracy = 0, 0
206 nb_eval_steps, nb_eval_examples = 0, 0
207
208 preds = None
209 out_label_ids = None
210 all_results = []
211 for step, batch in enumerate(progress_bar(self.data.val_dl)):
212 self.model.eval()
213 batch = tuple(t.to(self.device) for t in batch)
214
215 with torch.no_grad():
216 inputs = {'input_ids': batch[0],
217 'attention_mask': batch[1],
218 'token_type_ids': None if self.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
219 }
220 example_indices = batch[3]
221 if self.model_type in ['xlnet', 'xlm']:
222 inputs.update({'cls_index': batch[4],
223 'p_mask': batch[5]})
224
225 outputs = self.model(**inputs)
226 tmp_eval_loss, logits = outputs[:2]
227 eval_loss += tmp_eval_loss.mean().item()
228
229 for i, example_index in enumerate(example_indices):
230 eval_feature = self.data.val_features[example_index.item()]
231 unique_id = int(eval_feature.unique_id)
232 if self.model_type in ['xlnet', 'xlm']:
233 # XLNet uses a more complex post-processing procedure
234 result = RawResultExtended(unique_id = unique_id,
235 start_top_log_probs = to_list(outputs[0][i]),
236 start_top_index = to_list(outputs[1][i]),
237 end_top_log_probs = to_list(outputs[2][i]),
238 end_top_index = to_list(outputs[3][i]),
239 cls_logits = to_list(outputs[4][i]))
240 else:
241 result = RawResult(unique_id = unique_id,
242 start_logits = to_list(outputs[0][i]),
243 end_logits = to_list(outputs[1][i]))
244 all_results.append(result)
245
246 nb_eval_steps += 1
247
248 # Compute predictions
249 output_prediction_file = os.path.join(self.validation_out, "predictions_{}.json".format(prefix))
250 output_nbest_file = os.path.join(self.validation_out, "nbest_predictions_{}.json".format(prefix))
251 if self.data.version_2_with_negative:
252 output_null_log_odds_file = os.path.join(self.validation_out, "null_odds_{}.json".format(prefix))

Callers 1

fit (Method) · 0.95

Calls 4

to_list (Function) · 0.85
write_predictions (Function) · 0.85
EVAL_OPTS (Class) · 0.85

Tested by

no test coverage detected