Complex multi-band spectrogram discriminator.

Parameters
----------
window_length : int
    Window length of STFT.
hop_factor : float, optional
    Hop factor of the STFT: the hop length is ``hop_factor * window_length``, by default 0.25.
sample_rate : int, optional
(
self,
window_length: int,
hop_factor: float = 0.25,
sample_rate: int = 44100,
bands: list = BANDS,
)
| 100 | |
| 101 | class MRD(nn.Module): |
def __init__(
    self,
    window_length: int,
    hop_factor: float = 0.25,
    sample_rate: int = 44100,
    bands: list = BANDS,
):
    """Complex multi-band spectrogram discriminator.

    Parameters
    ----------
    window_length : int
        Window length of STFT.
    hop_factor : float, optional
        Hop size of the STFT as a fraction of ``window_length``; the hop
        length is ``int(window_length * hop_factor)``. By default 0.25.
    sample_rate : int, optional
        Sampling rate of audio in Hz, by default 44100
    bands : list, optional
        Bands to run discriminator over. Each entry's first two items are
        multiplied by the number of one-sided frequency bins
        (``window_length // 2 + 1``) and truncated to integer bin indices.
        Presumably ``(low, high)`` fractions in ``[0, 1]`` -- confirm
        against ``BANDS``. The passed list is not mutated.
    """
    super().__init__()

    self.window_length = window_length
    self.hop_factor = hop_factor
    self.sample_rate = sample_rate
    self.stft_params = STFTParams(
        window_length=window_length,
        hop_length=int(window_length * hop_factor),
        match_stride=True,
    )

    # Number of frequency bins of a one-sided STFT for this window length
    # (this is a bin count, not the FFT size).
    n_bins = window_length // 2 + 1
    # Convert each band's fractional edges into integer bin indices.
    self.bands = [(int(b[0] * n_bins), int(b[1] * n_bins)) for b in bands]

    channels = 32

    def _make_band_convs() -> nn.ModuleList:
        # One conv stack for a single band. The input has 2 channels
        # (presumably real/imag parts of the complex spectrogram -- confirm
        # in ``spectrogram``); three layers use stride 2 along the second
        # spatial axis.
        return nn.ModuleList(
            [
                WNConv2d(2, channels, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )

    # A separate (non-shared) conv stack is built for every band.
    self.band_convs = nn.ModuleList([_make_band_convs() for _ in self.bands])
    self.conv_post = WNConv2d(
        channels, 1, (3, 3), (1, 1), padding=(1, 1), act=False
    )
| 148 | |
| 149 | def spectrogram(self, x): |
| 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) |