MCPcopy Index your code
hub / github.com/NVIDIA-NeMo/RL / validate

Function validate

nemo_rl/algorithms/sft.py:234–349  ·  view source on GitHub ↗

Run validation on the validation dataset.

(
    policy: PolicyInterface,
    val_dataloader: Optional[StatefulDataLoader],
    tokenizer,
    loss_fn,
    step: int,
    master_config: MasterConfig,
    val_batches: int,
    val_batch_size: int,
    val_mbs: int,
)

Source from the content-addressed store, hash-verified

232# Training & Validation
233# =======================================================
234def validate(
235 policy: PolicyInterface,
236 val_dataloader: Optional[StatefulDataLoader],
237 tokenizer,
238 loss_fn,
239 step: int,
240 master_config: MasterConfig,
241 val_batches: int,
242 val_batch_size: int,
243 val_mbs: int,
244):
245 """Run validation on the validation dataset."""
246 if val_dataloader is None:
247 assert master_config["sft"]["val_period"] <= 0, (
248 "val_dataloader is None, so sft.val_period must be <= 0"
249 )
250 print(" ⚠️ No validation dataloader provided, skipping validation")
251 return {}, {}
252
253 timer = Timer()
254
255 with timer.time("total_validation_time"):
256 print(f"▶ Starting validation at step {step}...")
257
258 # Show a progress indicator for validation
259 # val_total = len(val_dataloader)
260
261 val_metrics = {"val_loss": 0.0}
262 sum_num_valid_tokens = 0
263
264 policy.prepare_for_training()
265 for batch_idx, val_batch in enumerate(val_dataloader):
266 ## add loss mask based on role to every message
267 add_loss_mask_to_message_log(
268 val_batch["message_log"],
269 roles_to_train_on=["assistant"],
270 )
271
272 cat_and_padded, input_lengths = batched_message_log_to_flat_message(
273 val_batch["message_log"],
274 pad_value_dict={"token_ids": tokenizer.pad_token_id},
275 make_sequence_length_divisible_by=master_config["policy"][
276 "make_sequence_length_divisible_by"
277 ],
278 )
279
280 val_data: BatchedDataDict = BatchedDataDict(
281 {
282 "input_ids": cat_and_padded["token_ids"],
283 "input_lengths": input_lengths,
284 "token_mask": cat_and_padded["token_loss_mask"],
285 "sample_mask": val_batch["loss_multiplier"],
286 }
287 )
288
289 # update multimodal data
290 val_data.update(cat_and_padded.get_multimodal_dict(as_tensors=False))
291 # When running validation with drop_last=False, we might end up with a partial batch.

Callers 1

sft_train · Function · 0.70

Calls 13

time · Method · 0.95
get_timing_metrics · Method · 0.95
reset · Method · 0.95
Timer · Class · 0.90
BatchedDataDict · Class · 0.90
maybe_pad_last_batch · Function · 0.90
update · Method · 0.80
get_multimodal_dict · Method · 0.80
get_axis_size · Method · 0.80
prepare_for_training · Method · 0.45

Tested by

no test coverage detected