Select up to 2 fixed validation samples, generate audio for each, and log the results to TensorBoard.
(
model,
val_ds,
audio_vae,
writer,
step,
accelerator,
sample_rate=22050,
out_sample_rate=0,
val_texts=None,
tokenizer=None,
pretrained_path=None,
valid_interval=1000,
tracker=None,
)
| 536 | |
| 537 | |
| 538 | def generate_sample_audio( |
| 539 | model, |
| 540 | val_ds, |
| 541 | audio_vae, |
| 542 | writer, |
| 543 | step, |
| 544 | accelerator, |
| 545 | sample_rate=22050, |
| 546 | out_sample_rate=0, |
| 547 | val_texts=None, |
| 548 | tokenizer=None, |
| 549 | pretrained_path=None, |
| 550 | valid_interval=1000, |
| 551 | tracker=None, |
| 552 | ): |
| 553 | """Select 2 fixed validation samples, generate audio and log to TensorBoard""" |
| 554 | import numpy as np |
| 555 | |
| 556 | log = tracker.print if tracker else print |
| 557 | num_samples = min(2, len(val_ds)) |
| 558 | log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}") |
| 559 | |
| 560 | unwrapped_model = accelerator.unwrap(model) |
| 561 | # Determine the correct output sample rate for generated audio. |
| 562 | # out_sample_rate is the decoder output rate (e.g. 48kHz for V2); |
| 563 | # sample_rate is the encoder input rate (e.g. 16kHz for V2). |
| 564 | gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate |
| 565 | |
| 566 | for i in range(num_samples): |
| 567 | sample = val_ds[i] |
| 568 | text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test." |
| 569 | |
| 570 | # Load reference audio |
| 571 | ref_audio_np = None |
| 572 | try: |
| 573 | if "audio" in sample and isinstance(sample["audio"], dict) and "array" in sample["audio"]: |
| 574 | ref_audio_np = np.array(sample["audio"]["array"], dtype=np.float32) |
| 575 | ref_sr = sample["audio"].get("sampling_rate", sample_rate) |
| 576 | if ref_sr != sample_rate: |
| 577 | import torchaudio.functional as F |
| 578 | |
| 579 | ref_audio_np = ( |
| 580 | F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy() |
| 581 | ) |
| 582 | log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s") |
| 583 | except Exception as e: |
| 584 | log(f"[Warning] Failed to load reference audio: {e}") |
| 585 | |
| 586 | # Preserve the original mode so validation failures do not leak into training. |
| 587 | prev_training = unwrapped_model.training |
| 588 | try: |
| 589 | # Inference setup |
| 590 | unwrapped_model.eval() |
| 591 | # unwrapped_model.to(torch.bfloat16) |
| 592 | unwrapped_model.audio_vae = audio_vae.to(torch.float32) |
| 593 | |
| 594 | log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'") |
| 595 | autocast_ctx = ( |
no test coverage detected
searching dependent graphs…