Function validate

github.com/OpenBMB/VoxCPM · scripts/train_voxcpm_finetune.py, lines 365–469

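Validate and generate sample audio. The listed portion switches the model to eval mode and runs at most 10 validation batches (`max_val_batches`) under `torch.no_grad()` and bfloat16 autocast. For each batch it sums every `loss/*` output weighted by `lambdas` (default weight 1.0), keeps the unweighted sub-losses per key, and finally reduces the mean total loss across processes with `accelerator.all_reduce`. The rest of the function (through line 469) is not shown here. According to the call graph, it also calls `generate_sample_audio`, `log_metrics`, and `print`.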


```python
def validate(
    model,
    val_loader,
    batch_processor,
    accelerator,
    tracker,
    lambdas,
    writer=None,
    step=0,
    val_ds=None,
    audio_vae=None,
    sample_rate=22050,
    out_sample_rate=0,
    val_texts=None,
    tokenizer=None,
    valid_interval=1000,
):
    """Validate and generate sample audio"""
    import numpy as np  # noqa: F401
    from collections import defaultdict
    # `torch` is used below; assumed to be imported at module level (not shown)

    model.eval()
    total_losses = []
    sub_losses = defaultdict(list)  # Track individual sub-losses
    num_batches = 0
    max_val_batches = 10

    with torch.no_grad():
        for batch in val_loader:
            if num_batches >= max_val_batches:
                break
            processed = batch_processor(batch)
            with accelerator.autocast(dtype=torch.bfloat16):
                outputs = model(
                    processed["text_tokens"],
                    processed["text_mask"],
                    processed["audio_feats"],
                    processed["audio_mask"],
                    processed["loss_mask"],
                    processed["position_ids"],
                    processed["labels"],
                    progress=0.0,
                    sample_generate=False,
                )
            total = 0.0
            for key, value in outputs.items():
                if key.startswith("loss/"):
                    weighted_loss = lambdas.get(key, 1.0) * value
                    total += weighted_loss
                    sub_losses[key].append(value.detach())
            total_losses.append(total.detach())
            num_batches += 1

    if total_losses:
        # Compute mean total loss
        mean_total_loss = torch.stack(total_losses).mean()
        accelerator.all_reduce(mean_total_loss)
        # ... listing truncated here; the function continues to line 469
```
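The loss bookkeeping in the listing can be checked without a model. The sketch below is a plain-Python stand-in, not code from the VoxCPM repository: `aggregate_val_losses` is a hypothetical helper that takes already-computed per-batch output dicts (floats instead of tensors), applies the same 10-batch cap, weights every `loss/*` entry by `lambdas.get(key, 1.0)`, and returns the mean weighted total plus the mean of each unweighted sub-loss. What the real function does with `sub_losses` happens in the part not shown, so the per-key averaging here is an assumption.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Mapping, Optional, Tuple


def aggregate_val_losses(
    batch_outputs: Iterable[Mapping[str, float]],
    lambdas: Mapping[str, float],
    max_val_batches: int = 10,
) -> Tuple[Optional[float], Dict[str, float]]:
    """Hypothetical stand-in for the loss bookkeeping in `validate`.

    Returns (mean weighted total, {key: mean unweighted sub-loss}),
    or (None, {}) when no batch was seen.
    """
    total_losses: List[float] = []
    sub_losses: Dict[str, List[float]] = defaultdict(list)
    for num_batches, outputs in enumerate(batch_outputs):
        if num_batches >= max_val_batches:  # same cap as the listing
            break
        total = 0.0
        for key, value in outputs.items():
            if key.startswith("loss/"):
                # Keys missing from `lambdas` get weight 1.0
                total += lambdas.get(key, 1.0) * value
                sub_losses[key].append(value)  # stored unweighted
        total_losses.append(total)
    if not total_losses:
        return None, {}
    mean_total = sum(total_losses) / len(total_losses)
    # Assumption: sub-losses are summarized by their per-key mean
    mean_subs = {key: sum(vals) / len(vals) for key, vals in sub_losses.items()}
    return mean_total, mean_subs
```

With hypothetical keys, `aggregate_val_losses([{"loss/a": 0.5, "loss/b": 0.2, "lr": 1e-4}, {"loss/a": 0.3, "loss/b": 0.4}], {"loss/b": 2.0})` gives a mean total of about 1.0 (batch totals 0.9 and 1.1). The `lr` entry is ignored because it lacks the `loss/` prefix.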

Callers (1)

train · function · 0.85

Calls (5)

generate_sample_audio · function · 0.85
autocast · method · 0.80
all_reduce · method · 0.80
log_metrics · method · 0.80
print · method · 0.80
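`all_reduce` is a method of the script's `accelerator` object, and its implementation is not part of this page. If it averages the tensor across processes (an assumption), each rank contributes the mean of its own batch losses, so the result is a mean of per-rank means. That equals the mean over all batches only when every rank validated the same number of batches. The plain-Python sketch below shows the arithmetic; both helpers are hypothetical and the per-rank loss lists are made-up inputs.

```python
from typing import List, Sequence


def _mean(xs: Sequence[float]) -> float:
    return sum(xs) / len(xs)


def mean_of_rank_means(per_rank_losses: Sequence[Sequence[float]]) -> float:
    """Each rank averages its own batch losses, then ranks are averaged.

    Models an all-reduce that averages one scalar per rank (an assumption).
    """
    return _mean([_mean(losses) for losses in per_rank_losses])


def global_batch_mean(per_rank_losses: Sequence[Sequence[float]]) -> float:
    """Average over every batch from every rank (ranks weighted by batch count)."""
    flat: List[float] = [x for losses in per_rank_losses for x in losses]
    return _mean(flat)
```

With equal batch counts, `[[1.0, 3.0], [2.0, 4.0]]` gives 2.5 from both helpers. With unequal counts, `[[1.0, 3.0], [5.0]]` gives 3.5 from `mean_of_rank_means` but 3.0 from `global_batch_mean`.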

Tested by

no test coverage detected
