MCPcopy
hub / github.com/Audio-AGI/AudioSep / __call__

Method __call__

evaluation/evaluate_clotho.py:45–102  ·  view source on GitHub ↗

r"""Evaluate.

(
        self,
        pl_model: pl.LightningModule
    )

Source from the content-addressed store, hash-verified

43 self.audio_dir = 'evaluation/data/clotho'
44
def __call__(
    self,
    pl_model: pl.LightningModule
) -> "tuple[float, float]":
    r"""Evaluate ``pl_model`` on the Clotho evaluation set with caption queries.

    For every entry of ``self.eval_list``, two files are loaded from
    ``self.audio_dir`` as mono audio at ``self.sampling_rate``: the reference
    source ``segment-{idx}.wav`` and the mixture ``mixture-{idx}.wav``. The
    entry's caption is embedded with ``pl_model.query_encoder`` (text
    modality). That embedding conditions ``pl_model.ss_model``, which
    separates the mixture. The separated waveform is scored against the
    reference source.

    Args:
        pl_model: Lightning module exposing ``query_encoder`` and
            ``ss_model``. It is put into eval mode and is NOT switched back
            to its previous mode afterwards.

    Returns:
        ``(mean_sisdr, mean_sdri)``. The first value is the mean SI-SDR of the
        separated outputs. The second is the mean SDR improvement, i.e. the SDR
        of the separated output minus the SDR of the unprocessed mixture. Both
        are averaged over all clips (units are whatever ``calculate_sdr`` /
        ``calculate_sisdr`` return — presumably dB, confirm). If
        ``self.eval_list`` is empty, ``np.mean`` yields NaN for both values
        and emits a RuntimeWarning.
    """
    print('Evaluation on Clotho Evaluation with [caption] queries.')

    pl_model.eval()
    device = pl_model.device

    sisdrs_list = []
    sdris_list = []

    with torch.no_grad():
        for eval_data in tqdm(self.eval_list):
            # Only the clip id and its caption are used; the remaining three
            # fields of each entry are ignored by this evaluation.
            idx, caption, _, _, _ = eval_data

            source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
            mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

            # The returned sample rate is discarded: resampling to
            # self.sampling_rate is requested explicitly.
            source, _ = librosa.load(source_path, sr=self.sampling_rate, mono=True)
            mixture, _ = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

            # Baseline score of the unprocessed mixture, needed for SDRi.
            sdr_no_sep = calculate_sdr(ref=source, est=mixture)

            conditions = pl_model.query_encoder.get_query_embed(
                modality='text',
                text=[caption],
                device=device,
            )

            input_dict = {
                # Add batch and channel axes:
                # (segment_samples,) -> (1, 1, segment_samples)
                "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                "condition": conditions,
            }

            sep_segment = pl_model.ss_model(input_dict)["waveform"]
            # sep_segment: (batch_size=1, channels_num=1, segment_samples)

            sep_segment = sep_segment.squeeze(0).squeeze(0).detach().cpu().numpy()
            # sep_segment: (segment_samples,)

            sdri = calculate_sdr(ref=source, est=sep_segment) - sdr_no_sep
            sisdr = calculate_sisdr(ref=source, est=sep_segment)

            sisdrs_list.append(sisdr)
            sdris_list.append(sdri)

    mean_sisdr = np.mean(sisdrs_list)
    mean_sdri = np.mean(sdris_list)

    return mean_sisdr, mean_sdri

Callers

nothing calls this directly

Calls 4

calculate_sdr · Function · 0.90
calculate_sisdr · Function · 0.90
get_query_embed · Method · 0.80
append · Method · 0.80

Tested by

no test coverage detected