r"""Evalute.
(
self,
pl_model: pl.LightningModule
)
| 43 | self.audio_dir = 'evaluation/data/clotho' |
| 44 | |
def __call__(
    self,
    pl_model: pl.LightningModule
) -> tuple:
    r"""Evaluate ``pl_model`` on the Clotho evaluation set with caption queries.

    For every entry of ``self.eval_list`` the target segment
    (``segment-{idx}.wav``) and the mixture (``mixture-{idx}.wav``) are loaded
    from ``self.audio_dir`` as mono audio at ``self.sampling_rate``. The entry's
    caption is embedded with ``pl_model.query_encoder``, and the mixture is
    separated with ``pl_model.ss_model`` conditioned on that embedding.

    Args:
        pl_model: model exposing ``device``, ``query_encoder.get_query_embed``
            and ``ss_model``.

    Returns:
        ``(mean_sisdr, mean_sdri)``:
            - ``mean_sisdr`` is the SI-SDR of the separated signal against the
              target, averaged over all items.
            - ``mean_sdri`` is the SDR of the separated signal minus the SDR of
              the unprocessed mixture, averaged over all items.

            Both are ``np.mean`` results, so an empty ``self.eval_list`` yields
            NaN (numpy also emits a RuntimeWarning).

    Note:
        ``pl_model.eval()`` is called and the training mode is not restored,
        so the model is left in eval mode on return.
    """
    print('Evaluation on Clotho Evaluation with [caption] queries.')

    pl_model.eval()
    device = pl_model.device

    sisdrs_list = []
    sdris_list = []

    with torch.no_grad():
        for eval_data in tqdm(self.eval_list):
            # Only the item id and its caption are used; the remaining
            # fields of the entry are ignored here.
            idx, caption, _, _, _ = eval_data

            source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
            mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

            # librosa resamples to self.sampling_rate, so the returned
            # sampling rate is not needed.
            source, _ = librosa.load(source_path, sr=self.sampling_rate, mono=True)
            mixture, _ = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

            # Baseline score: the untouched mixture measured against the target.
            sdr_no_sep = calculate_sdr(ref=source, est=mixture)

            conditions = pl_model.query_encoder.get_query_embed(
                modality='text',
                text=[caption],
                device=device,
            )

            input_dict = {
                # Add batch and channel axes: (1, 1, samples).
                "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                "condition": conditions,
            }

            sep_segment = pl_model.ss_model(input_dict)["waveform"]
            # sep_segment: (batch_size=1, channels_num=1, segment_samples)

            sep_segment = sep_segment.squeeze(0).squeeze(0).detach().cpu().numpy()
            # sep_segment: (segment_samples,)

            sdr = calculate_sdr(ref=source, est=sep_segment)
            sdris_list.append(sdr - sdr_no_sep)
            sisdrs_list.append(calculate_sisdr(ref=source, est=sep_segment))

    mean_sisdr = np.mean(sisdrs_list)
    mean_sdri = np.mean(sdris_list)

    return mean_sisdr, mean_sdri
nothing calls this directly
no test coverage detected