MCPcopy
hub / github.com/CLUEbenchmark/CLUE / predict

Function predict

baselines/models_pytorch/classifier_pytorch/run_classifier.py:221–293  ·  view source on GitHub ↗
(args, model, tokenizer, label_list, prefix="")

Source from the content-addressed store, hash-verified

219
220
221def predict(args, model, tokenizer, label_list, prefix=""):
222 pred_task_names = (args.task_name,)
223 pred_outputs_dirs = (args.output_dir,)
224 label_map = {i: label for i, label in enumerate(label_list)}
225
226 for pred_task, pred_output_dir in zip(pred_task_names, pred_outputs_dirs):
227 pred_dataset = load_and_cache_examples(args, pred_task, tokenizer, data_type='test')
228 if not os.path.exists(pred_output_dir) and args.local_rank in [-1, 0]:
229 os.makedirs(pred_output_dir)
230
231 args.pred_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
232 # Note that DistributedSampler samples randomly
233 pred_sampler = SequentialSampler(pred_dataset) if args.local_rank == -1 else DistributedSampler(pred_dataset)
234 pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler, batch_size=args.pred_batch_size,
235 collate_fn=xlnet_collate_fn
236 if args.model_type in ['xlnet'] else collate_fn)
237
238 logger.info("******** Running prediction {} ********".format(prefix))
239 logger.info(" Num examples = %d", len(pred_dataset))
240 logger.info(" Batch size = %d", args.pred_batch_size)
241 nb_pred_steps = 0
242 preds = None
243 pbar = ProgressBar(n_total=len(pred_dataloader), desc="Predicting")
244 for step, batch in enumerate(pred_dataloader):
245 model.eval()
246 batch = tuple(t.to(args.device) for t in batch)
247 with torch.no_grad():
248 inputs = {'input_ids': batch[0],
249 'attention_mask': batch[1],
250 'labels': batch[3]}
251 if args.model_type != 'distilbert':
252 inputs['token_type_ids'] = batch[2] if (
253 'bert' in args.model_type or 'xlnet' in args.model_type) else None # XLM, DistilBERT and RoBERTa don't use segment_ids
254 outputs = model(**inputs)
255 _, logits = outputs[:2]
256 nb_pred_steps += 1
257 if preds is None:
258 if pred_task == 'copa':
259 preds = logits.softmax(-1).detach().cpu().numpy()
260 else:
261 preds = logits.detach().cpu().numpy()
262 else:
263 if pred_task == 'copa':
264 preds = np.append(preds, logits.softmax(-1).detach().cpu().numpy(), axis=0)
265 else:
266 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
267 pbar(step)
268 print(' ')
269 if args.output_mode == "classification":
270 predict_label = np.argmax(preds, axis=1)
271 elif args.output_mode == "regression":
272 predict_label = np.squeeze(preds)
273 if pred_task == 'copa':
274 predict_label = []
275 pred_logits = preds[:, 1]
276 i = 0
277 while (i < len(pred_logits) - 1):
278 if pred_logits[i] >= pred_logits[i + 1]:

Callers: 1

main (function) · 0.85

Calls: 4

ProgressBar (class) · 0.90

save_numpy (function) · 0.90

load_and_cache_examples (function) · 0.85

join (method) · 0.80

Tested by

no test coverage detected