| 153 | return losses, output |
| 154 | |
def validation_step(self, sample, batch_idx):
    """Compute validation losses for one batch and optionally plot results.

    The losses come from ``self.run_model`` called with ``infer=False``.
    For the first ``hparams['num_valid_plots']`` batches the model is run
    again with ``infer=True`` and the predicted mel, the ``fs2_mels``
    reference and the F0 curves are passed to the plotting helpers.

    Args:
        sample: Batch dict. Read here: ``txt_tokens``, ``energy``,
            ``mel2ph``, ``f0``, ``uv``, ``nsamples``, either ``spk_embed``
            or ``spk_ids`` (chosen by ``hparams['use_spk_id']``), and, when
            plotting, ``mels`` and ``fs2_mels``.
        batch_idx: Index of the batch within the validation run.

    Returns:
        The result of ``utils.tensors_to_scalars`` applied to a dict with
        the keys ``losses``, ``total_loss`` and ``nsamples``.
    """
    txt_tokens = sample['txt_tokens']  # [B, T_t]
    energy = sample['energy']
    # Speaker conditioning is either an embedding or an id, depending on config.
    spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
    mel2ph = sample['mel2ph']
    f0 = sample['f0']
    uv = sample['uv']

    outputs = {}
    # Only the losses are needed from the training-mode pass; the model output
    # used for plotting comes from the separate infer=True call below.
    outputs['losses'], _ = self.run_model(self.model, sample, return_output=True, infer=False)
    outputs['total_loss'] = sum(outputs['losses'].values())
    outputs['nsamples'] = sample['nsamples']
    outputs = utils.tensors_to_scalars(outputs)

    if batch_idx < hparams['num_valid_plots']:
        gt_mel = sample['mels']
        fs2_mel = sample['fs2_mels']
        model_out = self.model(
            txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy,
            ref_mels=[None, fs2_mel], infer=True)
        pred_mel = model_out['mel_out']

        if hparams.get('pe_enable'):
            # Pitch extractor enabled: derive both F0 curves from the mels.
            gt_f0 = self.pe(gt_mel)['f0_denorm_pred']  # pe predict from GT mel
            pred_f0 = self.pe(pred_mel)['f0_denorm_pred']  # pe predict from Pred mel
        else:
            gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
            pred_f0 = model_out.get('f0_denorm')

        self.plot_wav(batch_idx, gt_mel, pred_mel, is_mel=True, gt_f0=gt_f0, f0=pred_f0)
        self.plot_mel(batch_idx, gt_mel, pred_mel, name=f'diffmel_{batch_idx}')
        self.plot_mel(batch_idx, gt_mel, fs2_mel, name=f'fs2mel_{batch_idx}')
    return outputs
| 190 | |
| 191 | def test_step(self, sample, batch_idx): |
| 192 | spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') |