Validate models. test_iter: test data iterator. Returns: :obj:`nmt.Statistics`: validation loss statistics.
(self, test_iter, step, cal_lead=False, cal_oracle=False)
| 196 | return stats |
| 197 | |
| 198 | def test(self, test_iter, step, cal_lead=False, cal_oracle=False): |
| 199 | """ Validate models. |
| 200 | valid_iter: validate data iterator |
| 201 | Returns: |
| 202 | :obj:`nmt.Statistics`: validation loss statistics |
| 203 | """ |
| 204 | |
| 205 | # Set models in validating mode. |
| 206 | def _get_ngrams(n, text): |
| 207 | ngram_set = set() |
| 208 | text_length = len(text) |
| 209 | max_index_ngram_start = text_length - n |
| 210 | for i in range(max_index_ngram_start + 1): |
| 211 | ngram_set.add(tuple(text[i:i + n])) |
| 212 | return ngram_set |
| 213 | |
| 214 | def _block_tri(c, p): |
| 215 | tri_c = _get_ngrams(3, c.split()) |
| 216 | for s in p: |
| 217 | tri_s = _get_ngrams(3, s.split()) |
| 218 | if len(tri_c.intersection(tri_s)) > 0: |
| 219 | return True |
| 220 | return False |
| 221 | |
| 222 | if not cal_lead and not cal_oracle: |
| 223 | self.model.eval() |
| 224 | stats = Statistics() |
| 225 | |
| 226 | can_path = '%s_step%d.candidate' % (self.args.result_path + self.args.data_name, step) |
| 227 | gold_path = '%s_step%d.gold' % (self.args.result_path + self.args.data_name, step) |
| 228 | origin_path = '%s_step%d.origin' % (self.args.result_path + self.args.data_name, step) |
| 229 | with open(can_path, 'w', encoding='utf-8') as save_pred: |
| 230 | with open(gold_path, 'w', encoding='utf-8') as save_gold: |
| 231 | with torch.no_grad(): |
| 232 | origin = [] |
| 233 | for batch in test_iter: |
| 234 | |
| 235 | src = batch.src # 7 sentences |
| 236 | # logger.info('origin sent: %s' % len(batch.src_str)) # 7 sentences |
| 237 | |
| 238 | labels = batch.labels |
| 239 | segs = batch.segs |
| 240 | clss = batch.clss |
| 241 | mask = batch.mask |
| 242 | mask_cls = batch.mask_cls |
| 243 | |
| 244 | gold = [] |
| 245 | pred = [] |
| 246 | |
| 247 | if cal_lead: |
| 248 | selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size |
| 249 | elif cal_oracle: |
| 250 | selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in |
| 251 | range(batch.batch_size)] |
| 252 | else: |
| 253 | sent_scores, mask, last_status = self.model(src, segs, clss, mask, mask_cls) |
| 254 | |
| 255 | loss = self.loss(sent_scores, labels.float()) |
nothing calls this directly
no test coverage detected