Validate the model. valid_iter: validation data iterator. Returns: :obj:`nmt.Statistics`: validation loss statistics.
(self, valid_iter, step=0)
| 168 | return total_stats |
| 169 | |
def validate(self, valid_iter, step=0):
    """Validate the model over every batch of a validation iterator.

    Args:
        valid_iter: iterable of batches; each batch must expose the
            attributes ``src``, ``labels``, ``segs``, ``clss``, ``mask``
            and ``mask_cls``.
        step: step number, forwarded unchanged to ``self._report_step``.

    Returns:
        :obj:`nmt.Statistics`: validation loss statistics accumulated over
        all batches in ``valid_iter``.

    Note:
        ``self.model.eval()`` is called at the start and the model is left
        in that mode on return; this method never switches it back.
    """
    # Set the model in validating mode.
    self.model.eval()
    stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            src = batch.src
            labels = batch.labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask
            mask_cls = batch.mask_cls

            # The model returns its own sentence mask alongside the scores;
            # that returned mask (not the input ``mask``) gates the loss.
            sent_scores, sent_mask = self.model(src, segs, clss, mask, mask_cls)

            # NOTE(review): the element-wise product below presumes self.loss
            # returns an unreduced per-sentence loss (e.g. reduction='none');
            # confirm against where self.loss is constructed.
            loss = self.loss(sent_scores, labels.float())
            loss = (loss * sent_mask.float()).sum()

            # .item() reads a one-element tensor as a Python number; it
            # replaces the legacy float(loss.cpu().data.numpy()) chain.
            batch_stats = Statistics(loss.item(), len(labels))
            stats.update(batch_stats)

    self._report_step(0, step, valid_stats=stats)
    return stats
| 197 | |
| 198 | def test(self, test_iter, step, cal_lead=False, cal_oracle=False): |
| 199 | """ Validate models. |
nothing calls this directly
no test coverage detected