| 47 | v.requires_grad = False |
| 48 | |
def validation_step(self, sample, batch_idx):
    """Compute validation losses for one batch and plot the first few batches.

    The losses come from ``self.run_model`` called with ``infer=False``. For
    batches whose index is below ``hparams['num_valid_plots']`` the model is
    additionally run with ``infer=True`` and the result is handed to
    ``self.plot_wav`` / ``self.plot_mel``.

    Args:
        sample: Batch dictionary. ``'nsamples'`` is always read. When the
            batch is plotted, ``'txt_tokens'``, ``'mels'``, ``'mel2ph'``,
            ``'f0'``, ``'uv'`` and ``'energy'`` are read as well, plus
            ``'spk_ids'`` or ``'spk_embed'`` (optional, fetched with
            ``dict.get``) depending on ``hparams['use_spk_id']``.
        batch_idx: Index of the validation batch; controls whether plots
            are produced.

    Returns:
        The dictionary ``{'losses', 'total_loss', 'nsamples'}`` after being
        passed through ``utils.tensors_to_scalars``.
    """
    # Only the loss dict is needed here; the second return value (the
    # teacher-forced model output) is not used by this method.
    losses, _ = self.run_model(self.model, sample, return_output=True, infer=False)
    outputs = {
        'losses': losses,
        'total_loss': sum(losses.values()),
        'nsamples': sample['nsamples'],
    }
    outputs = utils.tensors_to_scalars(outputs)

    # Guard clause: beyond the configured number of plots, nothing else to do.
    if batch_idx >= hparams['num_valid_plots']:
        return outputs

    # Speaker conditioning is taken from speaker ids when 'use_spk_id' is set,
    # otherwise from speaker embeddings; either may be absent (-> None).
    spk_key = 'spk_ids' if hparams['use_spk_id'] else 'spk_embed'
    gt_mel = sample['mels']  # annotated as [B, T_s, 80] in the original code

    # Separate inference-mode forward pass used purely for visualisation.
    model_out = self.model(
        sample['txt_tokens'],  # annotated as [B, T_t] in the original code
        spk_embed=sample.get(spk_key),
        mel2ph=sample['mel2ph'],
        f0=sample['f0'],
        uv=sample['uv'],
        energy=sample['energy'],
        ref_mels=None,
        infer=True,
    )
    pred_mel = model_out['mel_out']

    if hparams.get('pe_enable'):
        # Pitch extractor enabled: derive both curves from mels.
        gt_f0 = self.pe(gt_mel)['f0_denorm_pred']  # pe predict from GT mel
        pred_f0 = self.pe(pred_mel)['f0_denorm_pred']  # pe predict from Pred mel
    else:
        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
        pred_f0 = model_out.get('f0_denorm')

    self.plot_wav(batch_idx, gt_mel, pred_mel, is_mel=True, gt_f0=gt_f0, f0=pred_f0)
    self.plot_mel(batch_idx, gt_mel, pred_mel, name=f'diffmel_{batch_idx}')
    self.plot_mel(batch_idx, gt_mel, model_out['fs2_mel'], name=f'fs2mel_{batch_idx}')
    return outputs
| 83 | |
| 84 | |
| 85 | class ShallowDiffusionOfflineDataset(FastSpeechDataset): |