MCPcopy
hub / github.com/Turing-Project/WriteGPT / validate

Method validate

LanguageNetwork/BERT/models/trainer.py:170–196  ·  view source on GitHub ↗

Validates the model on a validation data iterator. Parameter `valid_iter`: the validation data iterator. Returns: :obj:`nmt.Statistics`, the validation loss statistics.

(self, valid_iter, step=0)

Source from the content-addressed store, hash-verified

168 return total_stats
169
def validate(self, valid_iter, step=0):
    """Run the model over a validation iterator and collect loss statistics.

    The model is switched to evaluation mode and every batch is processed
    with gradient tracking disabled. The statistics accumulated over all
    batches are handed to ``self._report_step`` and then returned.

    Args:
        valid_iter: iterable yielding validation batches; each batch is
            read through its ``src``, ``labels``, ``segs``, ``clss``,
            ``mask`` and ``mask_cls`` attributes.
        step: step number forwarded to ``self._report_step`` (default 0).

    Returns:
        Statistics: the statistics accumulated over every batch.
    """
    # Evaluation mode; note that this method does not switch the model back.
    self.model.eval()
    total = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            targets = batch.labels

            # The model returns a mask alongside the scores; that returned
            # mask (not batch.mask) is what weights the loss below.
            scores, out_mask = self.model(
                batch.src, batch.segs, batch.clss, batch.mask, batch.mask_cls
            )

            elementwise = self.loss(scores, targets.float())
            masked_sum = (elementwise * out_mask.float()).sum()
            loss_value = float(masked_sum.cpu().data.numpy())

            # The second argument is len(labels) for this batch, presumably
            # the number of documents; confirm against Statistics.
            total.update(Statistics(loss_value, len(targets)))

        # Report once, after all batches have been folded into `total`.
        self._report_step(0, step, valid_stats=total)
        return total
197
198 def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
199 """ Validate models.

Callers

nothing calls this directly

Calls 3

updateMethod · 0.95
_report_stepMethod · 0.95
StatisticsClass · 0.90

Tested by

no test coverage detected