Evaluate the model and log appropriate metrics. Args: evaluator (Evaluator): The evaluator to use for evaluation. metrics (dict[str, Metric]): Dictionary mapping metric names to metrics to evaluate against. subset_num_batches (int, optional): If specified, evaluate on this many batches. If omitted (``None``), it is set to ``-1``, which means to iterate over the entire dataloader.
(
self,
evaluator: Evaluator,
metrics: dict[str, Metric],
subset_num_batches: Optional[int] = None,
)
| 3312 | self.engine.run_event(Event.EVAL_STANDALONE_END) |
| 3313 | |
| 3314 | def _eval_loop( |
| 3315 | self, |
| 3316 | evaluator: Evaluator, |
| 3317 | metrics: dict[str, Metric], |
| 3318 | subset_num_batches: Optional[int] = None, |
| 3319 | ): |
| 3320 | """Evaluate the model and log appropriate metrics. |
| 3321 | |
| 3322 | Args: |
| 3323 | evaluator (Evaluator): The evaluator to use for evaluation. |
| 3324 | metrics (dict[str, Metric]): Dictionary mapping metric names to metrics to evaluate against. |
| 3325 | subset_num_batches (int, optional): If specified, evaluate on this many batches. Defaults to ``-1``, |
| 3326 | which means to iterate over the entire dataloader. |
| 3327 | """ |
| 3328 | if subset_num_batches is None: |
| 3329 | subset_num_batches = -1 |
| 3330 | |
| 3331 | # back up the original dataloader on the state, so we can restore it after evaluation is finished |
| 3332 | original_dataloader = self.state.dataloader |
| 3333 | original_dataloader_label = self.state.dataloader_label |
| 3334 | original_num_batches = self.state.dataloader_len |
| 3335 | |
| 3336 | # Unpack data_spec |
| 3337 | data_spec = evaluator.dataloader |
| 3338 | |
| 3339 | # Reset the eval timestamp |
| 3340 | self.state.eval_timestamp = Timestamp() |
| 3341 | |
| 3342 | last_wct = datetime.datetime.now() |
| 3343 | |
| 3344 | with torch.no_grad(), model_eval_mode(self.state.model): |
| 3345 | self.state.set_dataloader(data_spec.dataloader, evaluator.label, subset_num_batches) |
| 3346 | assert self.state.dataloader is not None, 'dataloader is set' |
| 3347 | |
| 3348 | self.engine.run_event(Event.EVAL_START) |
| 3349 | |
| 3350 | # On MPS device we ensure the eval metrics are computed on CPU to avoid numerical errors |
| 3351 | metrics = self._ensure_metrics_device_and_dtype( |
| 3352 | metrics, |
| 3353 | ensure_cpu=isinstance(self.state.device, DeviceMPS), |
| 3354 | ) |
| 3355 | |
| 3356 | for metric in metrics.values(): |
| 3357 | metric.reset() |
| 3358 | |
| 3359 | dataloader = self.state.dataloader |
| 3360 | drop_last = None |
| 3361 | dataset_len = None |
| 3362 | last_batch = False |
| 3363 | first_eval_batch_complete = False |
| 3364 | dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None |
| 3365 | if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader): |
| 3366 | # The distributed sampler uses `set_epoch` to set the random seed |
| 3367 | # Because evaluation can run on each batch, we use the batch to seed the sampler |
| 3368 | # so each evaluation will get a proper shuffle. |
| 3369 | # The epoch provided to `set_epoch` need not be sequential, so this is fine. |
| 3370 | dist_sampler.set_epoch(int(self.state.timestamp.batch)) |
| 3371 | drop_last = dataloader.drop_last |
no test coverage detected