(args, model, tokenizer, label_list, prefix="")
| 219 | |
| 220 | |
| 221 | def predict(args, model, tokenizer, label_list, prefix=""): |
| 222 | pred_task_names = (args.task_name,) |
| 223 | pred_outputs_dirs = (args.output_dir,) |
| 224 | label_map = {i: label for i, label in enumerate(label_list)} |
| 225 | |
| 226 | for pred_task, pred_output_dir in zip(pred_task_names, pred_outputs_dirs): |
| 227 | pred_dataset = load_and_cache_examples(args, pred_task, tokenizer, data_type='test') |
| 228 | if not os.path.exists(pred_output_dir) and args.local_rank in [-1, 0]: |
| 229 | os.makedirs(pred_output_dir) |
| 230 | |
| 231 | args.pred_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) |
| 232 | # Note that DistributedSampler samples randomly |
| 233 | pred_sampler = SequentialSampler(pred_dataset) if args.local_rank == -1 else DistributedSampler(pred_dataset) |
| 234 | pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler, batch_size=args.pred_batch_size, |
| 235 | collate_fn=xlnet_collate_fn |
| 236 | if args.model_type in ['xlnet'] else collate_fn) |
| 237 | |
| 238 | logger.info("******** Running prediction {} ********".format(prefix)) |
| 239 | logger.info(" Num examples = %d", len(pred_dataset)) |
| 240 | logger.info(" Batch size = %d", args.pred_batch_size) |
| 241 | nb_pred_steps = 0 |
| 242 | preds = None |
| 243 | pbar = ProgressBar(n_total=len(pred_dataloader), desc="Predicting") |
| 244 | for step, batch in enumerate(pred_dataloader): |
| 245 | model.eval() |
| 246 | batch = tuple(t.to(args.device) for t in batch) |
| 247 | with torch.no_grad(): |
| 248 | inputs = {'input_ids': batch[0], |
| 249 | 'attention_mask': batch[1], |
| 250 | 'labels': batch[3]} |
| 251 | if args.model_type != 'distilbert': |
| 252 | inputs['token_type_ids'] = batch[2] if ( |
| 253 | 'bert' in args.model_type or 'xlnet' in args.model_type) else None # XLM, DistilBERT and RoBERTa don't use segment_ids |
| 254 | outputs = model(**inputs) |
| 255 | _, logits = outputs[:2] |
| 256 | nb_pred_steps += 1 |
| 257 | if preds is None: |
| 258 | if pred_task == 'copa': |
| 259 | preds = logits.softmax(-1).detach().cpu().numpy() |
| 260 | else: |
| 261 | preds = logits.detach().cpu().numpy() |
| 262 | else: |
| 263 | if pred_task == 'copa': |
| 264 | preds = np.append(preds, logits.softmax(-1).detach().cpu().numpy(), axis=0) |
| 265 | else: |
| 266 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) |
| 267 | pbar(step) |
| 268 | print(' ') |
| 269 | if args.output_mode == "classification": |
| 270 | predict_label = np.argmax(preds, axis=1) |
| 271 | elif args.output_mode == "regression": |
| 272 | predict_label = np.squeeze(preds) |
| 273 | if pred_task == 'copa': |
| 274 | predict_label = [] |
| 275 | pred_logits = preds[:, 1] |
| 276 | i = 0 |
| 277 | while (i < len(pred_logits) - 1): |
| 278 | if pred_logits[i] >= pred_logits[i + 1]: |
no test coverage detected