github.com/Turing-Project/WriteGPT

Method test

LanguageNetwork/BERT/models/trainer.py:198–305

Validate models. test_iter: test data iterator. Returns: :obj:`nmt.Statistics`: validation loss statistics.

(self, test_iter, step, cal_lead=False, cal_oracle=False)
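The `cal_lead` and `cal_oracle` flags switch the method away from model scoring and onto two reference baselines. In the listing, `cal_lead` selects every sentence position in document order. `cal_oracle` keeps only the positions whose gold label is 1. The sketch below reproduces just those two branches on plain Python lists. `select_ids` is a hypothetical helper, and the nested lists stand in for `batch.labels`; both are assumptions for illustration, not the repository's tensor-based code.

```python
from typing import List


def select_ids(labels: List[List[int]], num_sents: int,
               cal_lead: bool = False, cal_oracle: bool = False) -> List[List[int]]:
    """Mirror the lead/oracle branches of test() on plain lists.

    Hypothetical helper: `labels` holds per-example 0/1 extraction labels,
    padded to `num_sents`, standing in for batch.labels.
    """
    batch_size = len(labels)
    if cal_lead:
        # Lead baseline: every sentence position, in document order.
        # A fresh list per example avoids sharing one list object.
        return [list(range(num_sents)) for _ in range(batch_size)]
    if cal_oracle:
        # Oracle baseline: keep only positions labelled 1 in the reference.
        return [[j for j in range(num_sents) if labels[i][j] == 1]
                for i in range(batch_size)]
    # The model-scored branch needs sent_scores and is not sketched here.
    raise ValueError("model-scored selection is not covered by this sketch")


labels = [[1, 0, 0, 1], [0, 1, 0, 0]]
print(select_ids(labels, 4, cal_lead=True))    # [[0, 1, 2, 3], [0, 1, 2, 3]]
print(select_ids(labels, 4, cal_oracle=True))  # [[0, 3], [1]]
```

The sketch differs from the original in one respect. The original builds the lead list as `[list(...)] * batch.batch_size`, which repeats a reference to a single list object. That is harmless while the lists are only read. The comprehension here gives each example its own copy.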


```python
# Excerpt of a Trainer method; torch and Statistics are module-level names in trainer.py (imports not shown).
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """ Validate models.
        test_iter: test data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
    """

    def _get_ngrams(n, text):
        # Collect every contiguous n-token window of `text` as a tuple.
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        # True if candidate sentence c shares any trigram with a sentence already in p.
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    # Set models in validating mode (only needed when the model itself is scored).
    if not cal_lead and not cal_oracle:
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path + self.args.data_name, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path + self.args.data_name, step)
    origin_path = '%s_step%d.origin' % (self.args.result_path + self.args.data_name, step)
    with open(can_path, 'w', encoding='utf-8') as save_pred:
        with open(gold_path, 'w', encoding='utf-8') as save_gold:
            with torch.no_grad():
                origin = []
                for batch in test_iter:

                    src = batch.src  # 7 sentences
                    # logger.info('origin sent: %s' % len(batch.src_str))  # 7 sentences

                    labels = batch.labels
                    segs = batch.segs
                    clss = batch.clss
                    mask = batch.mask
                    mask_cls = batch.mask_cls

                    gold = []
                    pred = []

                    if cal_lead:
                        # Lead baseline: every sentence position, in document order.
                        selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                    elif cal_oracle:
                        # Oracle baseline: positions whose gold label is 1.
                        selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1]
                                        for i in range(batch.batch_size)]
                    else:
                        sent_scores, mask, last_status = self.model(src, segs, clss, mask, mask_cls)

                        loss = self.loss(sent_scores, labels.float())
                        # ... listing truncated here; the rest of the method (through line 305) is not shown.
```
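The nested helpers `_get_ngrams` and `_block_tri` implement trigram blocking. A candidate sentence is rejected if it shares any three-token sequence with a sentence already chosen. The listing stops before the loop that calls `_block_tri`, so the greedy selection below is an assumption. It follows the common BertSum-style pattern: rank sentences by score, skip blocked ones, and stop at a fixed budget. The function names and the `k=3` budget are hypothetical. The n-gram and blocking logic matches the two helpers in the listing.

```python
from typing import List, Sequence, Set, Tuple


def get_ngrams(n: int, tokens: Sequence[str]) -> Set[Tuple[str, ...]]:
    # Same logic as _get_ngrams: every contiguous n-token window as a tuple
    # (empty when there are fewer than n tokens).
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}


def block_tri(candidate: str, selected: List[str]) -> bool:
    # Same logic as _block_tri: True if the candidate shares any trigram
    # with a sentence that has already been selected.
    tri_c = get_ngrams(3, candidate.split())
    return any(tri_c & get_ngrams(3, s.split()) for s in selected)


def select_with_trigram_blocking(sents: List[str], scores: List[float], k: int = 3) -> List[str]:
    # Hypothetical greedy selection (not shown in the listing): visit sentences
    # by descending score, skip any that repeat a trigram, stop after k picks.
    chosen: List[str] = []
    for idx in sorted(range(len(sents)), key=lambda i: -scores[i]):
        if not block_tri(sents[idx], chosen):
            chosen.append(sents[idx])
        if len(chosen) == k:
            break
    return chosen


sents = ["the cat sat on the mat", "the cat sat on a chair",
         "dogs bark loudly at night", "a bird sang"]
scores = [0.9, 0.8, 0.7, 0.1]
print(select_with_trigram_blocking(sents, scores))
# ['the cat sat on the mat', 'dogs bark loudly at night', 'a bird sang']
```

The second-highest sentence is skipped because it repeats "the cat sat" and "cat sat on" from the first pick. The lower-scoring, non-redundant sentences then fill the budget.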

Callers

nothing calls this directly

Calls (7)

update (method) · 0.95
_report_step (method) · 0.95
Statistics (class) · 0.90
save_txt_file (function) · 0.90
test_rouge (function) · 0.90
rouge_results_to_str (function) · 0.90
write (method) · 0.80

Tested by

no test coverage detected