(
trainer,
val_dls,
writer,
logger,
step_count,
update_panel: bool = False,
streaming: bool = False,
)
| 61 | |
| 62 | |
| 63 | def evaluate_on_val_dls( |
| 64 | trainer, |
| 65 | val_dls, |
| 66 | writer, |
| 67 | logger, |
| 68 | step_count, |
| 69 | update_panel: bool = False, |
| 70 | streaming: bool = False, |
| 71 | ): |
| 72 | with switch_sequence_parallel_mode(): |
| 73 | torch.cuda.empty_cache() |
| 74 | trainer.eval() |
| 75 | verbose = gpc.is_rank_for_log() |
| 76 | data_cfg = gpc.config.data |
| 77 | |
| 78 | for val_name, val_dl in val_dls.items(): |
| 79 | if len(val_dl) == 0 and verbose and not streaming: |
| 80 | logger.info(f"Validation dataset: {val_name} is empty") |
| 81 | continue |
| 82 | |
| 83 | val_metric = AccPerplex( |
| 84 | device=torch.cuda.current_device(), |
| 85 | tp_pg=gpc.get_group(ParallelMode.TENSOR), |
| 86 | dp_pg=gpc.get_group(ParallelMode.DATA), |
| 87 | ) |
| 88 | val_sche_metric_hook = SchedulerMetricHook(metric=val_metric) |
| 89 | |
| 90 | val_loss = 0 |
| 91 | val_idx = -1 |
| 92 | for val_idx, batch in tqdm( |
| 93 | enumerate(val_dl), |
| 94 | desc="Val.", |
| 95 | total=len(val_dl) if not streaming else None, |
| 96 | position=1, |
| 97 | disable=not verbose, |
| 98 | leave=False, |
| 99 | ): |
| 100 | with torch.inference_mode(): |
| 101 | if gpc.is_using_pp(): |
| 102 | total_val_bsz = len(batch[1]) |
| 103 | assert total_val_bsz % data_cfg.micro_bsz == 0 |
| 104 | num_microbatches = total_val_bsz // data_cfg.micro_bsz |
| 105 | tensor_shape = torch.Size( |
| 106 | [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] |
| 107 | ) |
| 108 | |
| 109 | with switch_evaluation_pipeline_scheduler( |
| 110 | trainer=trainer, |
| 111 | num_microbatches=num_microbatches, |
| 112 | tensor_shape=tensor_shape, |
| 113 | metric_hook_list=[val_sche_metric_hook], |
| 114 | ): |
| 115 | _, _, loss = trainer.execute_schedule( |
| 116 | batch, forward_only=True, return_loss=True, return_output_label=False |
| 117 | ) |
| 118 | else: |
| 119 | total_val_bsz = len(batch[1]) |
| 120 | assert total_val_bsz % data_cfg.micro_bsz == 0 |
No test coverage detected.