Run validation on the validation dataset.
(
policy: PolicyInterface,
val_dataloader: Optional[StatefulDataLoader],
tokenizer,
loss_fn,
step: int,
master_config: MasterConfig,
val_batches: int,
val_batch_size: int,
val_mbs: int,
)
| 232 | # Training & Validation |
| 233 | # ======================================================= |
| 234 | def validate( |
| 235 | policy: PolicyInterface, |
| 236 | val_dataloader: Optional[StatefulDataLoader], |
| 237 | tokenizer, |
| 238 | loss_fn, |
| 239 | step: int, |
| 240 | master_config: MasterConfig, |
| 241 | val_batches: int, |
| 242 | val_batch_size: int, |
| 243 | val_mbs: int, |
| 244 | ): |
| 245 | """Run validation on the validation dataset.""" |
| 246 | if val_dataloader is None: |
| 247 | assert master_config["sft"]["val_period"] <= 0, ( |
| 248 | "val_dataloader is None, so sft.val_period must be <= 0" |
| 249 | ) |
| 250 | print(" ⚠️ No validation dataloader provided, skipping validation") |
| 251 | return {}, {} |
| 252 | |
| 253 | timer = Timer() |
| 254 | |
| 255 | with timer.time("total_validation_time"): |
| 256 | print(f"▶ Starting validation at step {step}...") |
| 257 | |
| 258 | # Show a progress indicator for validation |
| 259 | # val_total = len(val_dataloader) |
| 260 | |
| 261 | val_metrics = {"val_loss": 0.0} |
| 262 | sum_num_valid_tokens = 0 |
| 263 | |
| 264 | policy.prepare_for_training() |
| 265 | for batch_idx, val_batch in enumerate(val_dataloader): |
| 266 | ## add loss mask based on role to every message |
| 267 | add_loss_mask_to_message_log( |
| 268 | val_batch["message_log"], |
| 269 | roles_to_train_on=["assistant"], |
| 270 | ) |
| 271 | |
| 272 | cat_and_padded, input_lengths = batched_message_log_to_flat_message( |
| 273 | val_batch["message_log"], |
| 274 | pad_value_dict={"token_ids": tokenizer.pad_token_id}, |
| 275 | make_sequence_length_divisible_by=master_config["policy"][ |
| 276 | "make_sequence_length_divisible_by" |
| 277 | ], |
| 278 | ) |
| 279 | |
| 280 | val_data: BatchedDataDict = BatchedDataDict( |
| 281 | { |
| 282 | "input_ids": cat_and_padded["token_ids"], |
| 283 | "input_lengths": input_lengths, |
| 284 | "token_mask": cat_and_padded["token_loss_mask"], |
| 285 | "sample_mask": val_batch["loss_multiplier"], |
| 286 | } |
| 287 | ) |
| 288 | |
| 289 | # update multimodal data |
| 290 | val_data.update(cat_and_padded.get_multimodal_dict(as_tensors=False)) |
| 291 | # When running validation with drop_last=False, we might end up with a partial batch. |
No test coverage detected.