Validate and generate sample audio
(
model,
val_loader,
batch_processor,
accelerator,
tracker,
lambdas,
writer=None,
step=0,
val_ds=None,
audio_vae=None,
sample_rate=22050,
out_sample_rate=0,
val_texts=None,
tokenizer=None,
valid_interval=1000,
)
| 363 | |
| 364 | |
| 365 | def validate( |
| 366 | model, |
| 367 | val_loader, |
| 368 | batch_processor, |
| 369 | accelerator, |
| 370 | tracker, |
| 371 | lambdas, |
| 372 | writer=None, |
| 373 | step=0, |
| 374 | val_ds=None, |
| 375 | audio_vae=None, |
| 376 | sample_rate=22050, |
| 377 | out_sample_rate=0, |
| 378 | val_texts=None, |
| 379 | tokenizer=None, |
| 380 | valid_interval=1000, |
| 381 | ): |
| 382 | """Validate and generate sample audio""" |
| 383 | import numpy as np # noqa: F401 |
| 384 | from collections import defaultdict |
| 385 | |
| 386 | model.eval() |
| 387 | total_losses = [] |
| 388 | sub_losses = defaultdict(list) # Track individual sub-losses |
| 389 | num_batches = 0 |
| 390 | max_val_batches = 10 |
| 391 | |
| 392 | with torch.no_grad(): |
| 393 | for batch in val_loader: |
| 394 | if num_batches >= max_val_batches: |
| 395 | break |
| 396 | processed = batch_processor(batch) |
| 397 | with accelerator.autocast(dtype=torch.bfloat16): |
| 398 | outputs = model( |
| 399 | processed["text_tokens"], |
| 400 | processed["text_mask"], |
| 401 | processed["audio_feats"], |
| 402 | processed["audio_mask"], |
| 403 | processed["loss_mask"], |
| 404 | processed["position_ids"], |
| 405 | processed["labels"], |
| 406 | progress=0.0, |
| 407 | sample_generate=False, |
| 408 | ) |
| 409 | total = 0.0 |
| 410 | for key, value in outputs.items(): |
| 411 | if key.startswith("loss/"): |
| 412 | weighted_loss = lambdas.get(key, 1.0) * value |
| 413 | total += weighted_loss |
| 414 | sub_losses[key].append(value.detach()) |
| 415 | total_losses.append(total.detach()) |
| 416 | num_batches += 1 |
| 417 | |
| 418 | if total_losses: |
| 419 | # Compute mean total loss |
| 420 | mean_total_loss = torch.stack(total_losses).mean() |
| 421 | accelerator.all_reduce(mean_total_loss) |
| 422 |
no test coverage detected
searching dependent graphs…