(self, sample, batch_idx)
| 115 | return total_loss, loss_output |
| 116 | |
def validation_step(self, sample, batch_idx):
    """Run the model on one validation batch and collect its loss statistics.

    Args:
        sample: the batch dict. This method reads ``sample['nsamples']``
            and passes the whole dict to ``run_model`` and ``plot_pitch``.
        batch_idx: index of this batch within the validation pass. A pitch
            plot is produced only while it is below
            ``hparams['num_valid_plots']``.

    Returns:
        A dict with these keys, passed through ``utils.tensors_to_scalars``
        (presumably turning tensor values into plain scalars; confirm in utils):
        ``'losses'`` is the per-term loss dict returned by ``run_model``;
        ``'total_loss'`` is the sum of those terms; ``'nsamples'`` is
        copied from the batch.
    """
    # With return_output=True, run_model yields (losses dict, model output).
    losses, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
    # Built in one literal. The previous `outputs['losses'] = {}` was a dead
    # store, overwritten immediately by the run_model result.
    outputs = {
        'losses': losses,
        # NOTE(review): sum() over an empty dict yields int 0. Confirm that
        # run_model always returns at least one loss term.
        'total_loss': sum(losses.values()),
        'nsamples': sample['nsamples'],
    }
    outputs = utils.tensors_to_scalars(outputs)
    # Only the first few validation batches get a pitch plot.
    if batch_idx < hparams['num_valid_plots']:
        self.plot_pitch(batch_idx, model_out, sample)
    return outputs
| 127 | |
| 128 | def run_model(self, model, sample, return_output=False, infer=False): |
| 129 | f0 = sample['f0'] |
nothing calls this directly
no test coverage detected