MCPcopy
hub / github.com/Audio-AGI/AudioSep / __call__

Method __call__

evaluation/evaluate_vggsound.py:48–114  ·  view source on GitHub ↗

r"""Evaluate.

(
        self,
        pl_model: pl.LightningModule
    )

Source from the content-addressed store, hash-verified

46 self.audio_dir = 'evaluation/data/vggsound'
47
48 def __call__(
49 self,
50 pl_model: pl.LightningModule
51 ) -> Dict:
52 r"""Evalute."""
53
54 print(f'Evaluation on VGGSound+ with [text label] queries.')
55
56 pl_model.eval()
57 device = pl_model.device
58
59 sisdrs_list = []
60 sdris_list = []
61 sisdris_list = []
62
63
64 with torch.no_grad():
65 for eval_data in tqdm(self.eval_list):
66
67 # labels, source_path, mixture_path = eval_data
68 file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text = eval_data
69
70 labels = s0_text
71
72 mixture_path = os.path.join(self.audio_dir, mix_wav)
73 source_path = os.path.join(self.audio_dir, s0_wav)
74
75
76 source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
77 mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)
78
79 sdr_no_sep = calculate_sdr(ref=source, est=mixture)
80
81 text = [labels]
82 conditions = pl_model.query_encoder.get_query_embed(
83 modality='text',
84 text=text,
85 device=device
86 )
87
88 input_dict = {
89 "mixture": torch.Tensor(mixture)[None, None, :].to(device),
90 "condition": conditions,
91 }
92
93 sep_segment = pl_model.ss_model(input_dict)["waveform"]
94 # sep_segment: (batch_size=1, channels_num=1, segment_samples)
95
96 sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
97 # sep_segment: (segment_samples,)
98
99 sdr = calculate_sdr(ref=source, est=sep_segment)
100 sdri = sdr - sdr_no_sep
101
102 sisdr_no_sep = calculate_sisdr(ref=source, est=mixture)
103 sisdr = calculate_sisdr(ref=source, est=sep_segment)
104 sisdri = sisdr - sisdr_no_sep
105

Callers

nothing calls this directly

Calls 4

calculate_sdr · Function · 0.90
calculate_sisdr · Function · 0.90
get_query_embed · Method · 0.80
append · Method · 0.80

Tested by

no test coverage detected