(self, prefix="SQUAD", n_best_size=20, max_answer_length=30, verbose_logging=False, null_score_diff_threshold=0.0)
| 193 | |
| 194 | ### Evaluate the model |
| 195 | def validate(self, prefix="SQUAD", n_best_size=20, max_answer_length=30, verbose_logging=False, null_score_diff_threshold=0.0): |
| 196 | self.logger.info("Running evaluation") |
| 197 | |
| 198 | self.logger.info(" Num examples = %d", len(self.data.val_dl.dataset)) |
| 199 | self.logger.info(" Batch size = %d", self.data.val_batch_size) |
| 200 | |
| 201 | all_logits = None |
| 202 | all_labels = None |
| 203 | |
| 204 | |
| 205 | eval_loss, eval_accuracy = 0, 0 |
| 206 | nb_eval_steps, nb_eval_examples = 0, 0 |
| 207 | |
| 208 | preds = None |
| 209 | out_label_ids = None |
| 210 | all_results = [] |
| 211 | for step, batch in enumerate(progress_bar(self.data.val_dl)): |
| 212 | self.model.eval() |
| 213 | batch = tuple(t.to(self.device) for t in batch) |
| 214 | |
| 215 | with torch.no_grad(): |
| 216 | inputs = {'input_ids': batch[0], |
| 217 | 'attention_mask': batch[1], |
| 218 | 'token_type_ids': None if self.model_type == 'xlm' else batch[2] # XLM doesn't use segment_ids |
| 219 | } |
| 220 | example_indices = batch[3] |
| 221 | if self.model_type in ['xlnet', 'xlm']: |
| 222 | inputs.update({'cls_index': batch[4], |
| 223 | 'p_mask': batch[5]}) |
| 224 | |
| 225 | outputs = self.model(**inputs) |
| 226 | tmp_eval_loss, logits = outputs[:2] |
| 227 | eval_loss += tmp_eval_loss.mean().item() |
| 228 | |
| 229 | for i, example_index in enumerate(example_indices): |
| 230 | eval_feature = self.data.val_features[example_index.item()] |
| 231 | unique_id = int(eval_feature.unique_id) |
| 232 | if self.model_type in ['xlnet', 'xlm']: |
| 233 | # XLNet uses a more complex post-processing procedure |
| 234 | result = RawResultExtended(unique_id = unique_id, |
| 235 | start_top_log_probs = to_list(outputs[0][i]), |
| 236 | start_top_index = to_list(outputs[1][i]), |
| 237 | end_top_log_probs = to_list(outputs[2][i]), |
| 238 | end_top_index = to_list(outputs[3][i]), |
| 239 | cls_logits = to_list(outputs[4][i])) |
| 240 | else: |
| 241 | result = RawResult(unique_id = unique_id, |
| 242 | start_logits = to_list(outputs[0][i]), |
| 243 | end_logits = to_list(outputs[1][i])) |
| 244 | all_results.append(result) |
| 245 | |
| 246 | nb_eval_steps += 1 |
| 247 | |
| 248 | # Compute predictions |
| 249 | output_prediction_file = os.path.join(self.validation_out, "predictions_{}.json".format(prefix)) |
| 250 | output_nbest_file = os.path.join(self.validation_out, "nbest_predictions_{}.json".format(prefix)) |
| 251 | if self.data.version_2_with_negative: |
| 252 | output_null_log_odds_file = os.path.join(self.validation_out, "null_odds_{}.json".format(prefix)) |
no test coverage detected