MCPcopy
hub / github.com/descriptinc/descript-audio-codec / __init__

Method __init__

dac/model/discriminator.py:102–147  ·  view source on GitHub ↗

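Complex multi-band spectrogram discriminator. Parameters: `window_length` (int) — window length of the STFT; `hop_factor` (float, optional) — hop length of the STFT expressed as a fraction of `window_length`, defaults to 0.25; `sample_rate` (int, optional) — sampling rate of audio in Hz, defaults to 44100; `bands` (list, optional) — bands to run the discriminator over.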

(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 44100,
        bands: list = BANDS,
    )

Source from the content-addressed store, hash-verified

100
101class MRD(nn.Module):
def __init__(
    self,
    window_length: int,
    hop_factor: float = 0.25,
    sample_rate: int = 44100,
    bands: list = BANDS,
):
    """Complex multi-band spectrogram discriminator.

    Parameters
    ----------
    window_length : int
        Window length of STFT.
    hop_factor : float, optional
        Hop length of the STFT expressed as a fraction of
        ``window_length``; the hop length actually used is
        ``int(window_length * hop_factor)``. By default 0.25
    sample_rate : int, optional
        Sampling rate of audio in Hz, by default 44100
    bands : list, optional
        Bands to run discriminator over. Each entry is indexed as
        ``(low, high)``; both values are multiplied by
        ``window_length // 2 + 1`` and truncated to integers before
        being stored on ``self.bands``.
    """
    super().__init__()

    self.window_length = window_length
    self.hop_factor = hop_factor
    self.sample_rate = sample_rate
    self.stft_params = STFTParams(
        window_length=window_length,
        hop_length=int(window_length * hop_factor),
        match_stride=True,
    )

    # window_length // 2 + 1 -- presumably the number of one-sided STFT
    # frequency bins; the fractional band edges are scaled by it.
    n_bins = window_length // 2 + 1
    self.bands = [
        (int(band[0] * n_bins), int(band[1] * n_bins)) for band in bands
    ]

    ch = 32

    def make_band_stack():
        # Build a fresh stack on every call so each band gets its own,
        # independently parameterized layers (no sharing across bands).
        # NOTE(review): positional args look like
        # (in_ch, out_ch, kernel, stride) -- confirm against WNConv2d.
        return nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )

    # One conv stack per band, created in band order.
    self.band_convs = nn.ModuleList([make_band_stack() for _ in self.bands])
    # Final projection to a single output channel; built with act=False.
    self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
149 def spectrogram(self, x):
150 x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)

Callers

nothing calls this directly

Calls 2

WNConv2dFunction · 0.85
__init__Method · 0.45

Tested by

no test coverage detected