r"""Calculate SI-SDR (scale-invariant SDR) between reference and estimation. Args: ref (np.ndarray): reference signal; est (np.ndarray): estimated signal.
(ref, est)
| 170 | |
| 171 | |
def calculate_sisdr(ref, est):
    r"""Calculate the scale-invariant SDR (SI-SDR) between reference and estimation.

    The estimate is split into a component along the reference
    (``a * ref``, where ``a`` is the least-squares projection coefficient)
    and a residual. SI-SDR is the energy ratio of the two, in decibels.

    Args:
        ref (np.ndarray): reference signal; flattened before use. Integer
            (or boolean) input is promoted to float64.
        est (np.ndarray): estimated signal; flattened before use. Must have
            the same number of elements as ``ref``.

    Returns:
        SI-SDR in dB, as a NumPy scalar.

    Raises:
        ValueError: if ``ref`` and ``est`` have different numbers of elements.
    """
    reference = np.asarray(ref)
    estimate = np.asarray(est)

    # np.finfo only accepts inexact dtypes, so integer audio (e.g. int16 PCM)
    # would raise there. Integer dot products could also overflow. Promote
    # such input to float64; float/complex input keeps its dtype.
    if not np.issubdtype(reference.dtype, np.inexact):
        reference = reference.astype(np.float64)
    if not np.issubdtype(estimate.dtype, np.inexact):
        estimate = estimate.astype(np.float64)

    if reference.size != estimate.size:
        raise ValueError(
            "ref and est must have the same number of elements, "
            f"got {reference.size} and {estimate.size}"
        )

    eps = np.finfo(reference.dtype).eps

    # Column vectors, so the dot products below yield (1, 1) arrays.
    reference = reference.reshape(-1, 1)
    estimate = estimate.reshape(-1, 1)

    ref_energy = np.dot(reference.T, reference)
    # Scaling factor that projects the estimate onto the reference; eps
    # keeps the division finite for an all-zero reference.
    scale = (eps + np.dot(reference.T, estimate)) / (ref_energy + eps)

    e_true = scale * reference
    e_res = estimate - e_true

    target_energy = (e_true ** 2).sum()
    residual_energy = (e_res ** 2).sum()

    return 10 * np.log10((eps + target_energy) / (eps + residual_energy))
| 201 | |
| 202 | |
| 203 | class StatisticsContainer(object): |