hub / github.com/mosaicml/composer / _eval_loop

Method _eval_loop

composer/trainer/trainer.py:3314–3566

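Evaluate the model and log appropriate metrics.

Args:
    evaluator (Evaluator): The evaluator to use for evaluation.
    metrics (dict[str, Metric]): Dictionary mapping metric names to metrics to evaluate against.
    subset_num_batches (int, optional): If specified, evaluate on this many batches. The signature default is None, which the method converts to -1, meaning iterate over the entire dataloader.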

(
    self,
    evaluator: Evaluator,
    metrics: dict[str, Metric],
    subset_num_batches: Optional[int] = None,
)
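
The signature defaults `subset_num_batches` to `None`, while the method body normalizes it to `-1` and the docstring treats `-1` as "iterate over the entire dataloader". The sketch below illustrates that sentinel convention only. `normalize_subset` and `take_batches` are hypothetical helpers, not part of Composer, and the real limit is applied through `self.state.set_dataloader(...)`, whose internals are not shown on this page.

```python
from itertools import islice
from typing import Iterable, Iterator, Optional


def normalize_subset(subset_num_batches: Optional[int]) -> int:
    """Map the public default (None) to the internal sentinel (-1)."""
    return -1 if subset_num_batches is None else subset_num_batches


def take_batches(batches: Iterable, subset_num_batches: Optional[int] = None) -> Iterator:
    """Yield every batch when the limit is -1, otherwise at most `limit` batches.

    Hypothetical illustration of the sentinel; not Composer's implementation.
    """
    limit = normalize_subset(subset_num_batches)
    if limit < 0:
        return iter(batches)
    return islice(batches, limit)


batches = ["b0", "b1", "b2", "b3"]
full = list(take_batches(batches))          # no limit given: all four batches
first_two = list(take_batches(batches, 2))  # explicit limit: first two batches
```

Normalizing once at the top of the function means the rest of the body can work with a plain `int` instead of re-checking for `None`.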

Source from the content-addressed store, hash-verified

        self.engine.run_event(Event.EVAL_STANDALONE_END)

    def _eval_loop(
        self,
        evaluator: Evaluator,
        metrics: dict[str, Metric],
        subset_num_batches: Optional[int] = None,
    ):
        """Evaluate the model and log appropriate metrics.

        Args:
            evaluator (Evaluator): The evaluator to use for evaluation.
            metrics (dict[str, Metric]): Dictionary mapping metric names to metrics to evaluate against.
            subset_num_batches (int, optional): If specified, evaluate on this many batches. Defaults to ``-1``,
                which means to iterate over the entire dataloader.
        """
        if subset_num_batches is None:
            subset_num_batches = -1

        # back up the original dataloader on the state, so we can restore it after evaluation is finished
        original_dataloader = self.state.dataloader
        original_dataloader_label = self.state.dataloader_label
        original_num_batches = self.state.dataloader_len

        # Unpack data_spec
        data_spec = evaluator.dataloader

        # Reset the eval timestamp
        self.state.eval_timestamp = Timestamp()

        last_wct = datetime.datetime.now()

        with torch.no_grad(), model_eval_mode(self.state.model):
            self.state.set_dataloader(data_spec.dataloader, evaluator.label, subset_num_batches)
            assert self.state.dataloader is not None, 'dataloader is set'

            self.engine.run_event(Event.EVAL_START)

            # On MPS device we ensure the eval metrics are computed on CPU to avoid numerical errors
            metrics = self._ensure_metrics_device_and_dtype(
                metrics,
                ensure_cpu=isinstance(self.state.device, DeviceMPS),
            )

            for metric in metrics.values():
                metric.reset()

            dataloader = self.state.dataloader
            drop_last = None
            dataset_len = None
            last_batch = False
            first_eval_batch_complete = False
            dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
            if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
                # The distributed sampler uses `set_epoch` to set the random seed
                # Because evaluation can run on each batch, we use the batch to seed the sampler
                # so each evaluation will get a proper shuffle.
                # The epoch provided to `set_epoch` need not be sequential, so this is fine.
                dist_sampler.set_epoch(int(self.state.timestamp.batch))
                drop_last = dataloader.drop_last
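
The comments at the end of the excerpt explain that the sampler is reseeded with the current training batch count, because `set_epoch` only needs an integer and does not require sequential values. The sketch below illustrates that idea. `ToyShuffleSampler` is a hypothetical stand-in built on the standard library, not `torch.utils.data.DistributedSampler`. It assumes the shuffle is seeded from a base seed plus the value passed to `set_epoch`.

```python
import random


class ToyShuffleSampler:
    """Hypothetical stand-in for a shuffling sampler with a `set_epoch` hook."""

    def __init__(self, dataset_len: int, seed: int = 0):
        self.dataset_len = dataset_len
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Any integer works; it only has to differ between runs that should
        # be allowed to see a different order.
        self.epoch = epoch

    def __iter__(self):
        # Seed a private RNG from (seed + epoch) so the order is reproducible
        # for a given epoch value.
        rng = random.Random(self.seed + self.epoch)
        indices = list(range(self.dataset_len))
        rng.shuffle(indices)
        return iter(indices)


sampler = ToyShuffleSampler(dataset_len=8)

# Evaluation triggered at training batch 100, then again at batch 250.
sampler.set_epoch(100)
order_at_100 = list(sampler)
sampler.set_epoch(250)
order_at_250 = list(sampler)

# Re-using the same batch count reproduces the same order.
sampler.set_epoch(100)
order_at_100_again = list(sampler)
```

Because the seed depends only on the integer passed in, every rank that calls `set_epoch` with the same batch count derives the same permutation, which is what lets a distributed sampler partition indices consistently across processes.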

Callers 2

_run_evaluators · Method · 0.95
eval · Method · 0.95

Calls 15

_iter_dataloader · Method · 0.95
Timestamp · Class · 0.90
model_eval_mode · Function · 0.90
get_precision_context · Function · 0.90
DeviceCPU · Class · 0.90
generate_oom_hook · Function · 0.90
_get_distributed_sampler · Function · 0.85
_is_cuda_oom · Function · 0.85

Tested by

no test coverage detected