MCPcopy
hub / github.com/Turing-Project/WriteGPT / predict

Method predict

LanguageNetwork/BERT/models/trainer.py:307–418  ·  view source on GitHub ↗

Validate models. `valid_iter` (named `test_iter` in the signature below): validation data iterator. Returns: :obj:`nmt.Statistics`: validation loss statistics.

(self, test_iter, step, cal_lead=False, cal_oracle=False)

Source from the content-addressed store, hash-verified

305 return stats
306
307 def predict(self, test_iter, step, cal_lead=False, cal_oracle=False):
308 """ Validate models.
309 valid_iter: validate data iterator
310 Returns:
311 :obj:`nmt.Statistics`: validation loss statistics
312 """
313
314 # Set models in validating mode.
315 def _get_ngrams(n, text):
316 ngram_set = set()
317 text_length = len(text)
318 max_index_ngram_start = text_length - n
319 for i in range(max_index_ngram_start + 1):
320 ngram_set.add(tuple(text[i:i + n]))
321 return ngram_set
322
323 def _block_tri(c, p):
324 tri_c = _get_ngrams(3, c.split())
325 for s in p:
326 tri_s = _get_ngrams(3, s.split())
327 if len(tri_c.intersection(tri_s)) > 0:
328 return True
329 return False
330
331 if not cal_lead and not cal_oracle:
332 self.model.eval()
333 stats = Statistics()
334
335 can_path = '%s_step%d.candidate' % (self.args.result_path + self.args.data_name, step)
336 gold_path = '%s_step%d.gold' % (self.args.result_path + self.args.data_name, step)
337 origin_path = '%s_step%d.origin' % (self.args.result_path + self.args.data_name, step)
338 with open(can_path, 'w', encoding='utf-8') as save_pred:
339 with open(gold_path, 'w', encoding='utf-8') as save_gold:
340 with torch.no_grad():
341 origin = []
342 for batch in test_iter:
343
344 src = batch.src # 7 sentences
345 # logger.info('origin sent: %s' % len(batch.src_str)) # 7 sentences
346
347 labels = batch.labels
348 segs = batch.segs
349 clss = batch.clss
350 mask = batch.mask
351 mask_cls = batch.mask_cls
352
353 gold = []
354 pred = []
355
356 if cal_lead:
357 selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
358 elif cal_oracle:
359 selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
360 range(batch.batch_size)]
361 else:
362 sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
363
364 loss = self.loss(sent_scores, labels.float())

Callers

nothing calls this directly

Calls 7

update · Method · 0.95
_report_step · Method · 0.95
Statistics · Class · 0.90
save_txt_file · Function · 0.90
test_rouge · Function · 0.90
rouge_results_to_str · Function · 0.90
write · Method · 0.80

Tested by

no test coverage detected