| 41 | |
| 42 | |
class EncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units followed by a strided conv.

    The residual stack operates on ``dim // 2`` channels with dilations
    1, 3 and 9. It is followed by a ``Snake1d`` activation and a ``WNConv1d``
    that maps ``dim // 2`` channels to ``dim`` channels with the given stride.

    Args:
        dim: Output channel count. The input presumably carries ``dim // 2``
            channels, since every layer before the final conv is built with
            that width.
        stride: Stride of the final convolution. Its kernel size is
            ``2 * stride``.
    """

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        in_channels = dim // 2

        # Submodules are created in execution order, so the Sequential
        # indices (block.0 .. block.4) match the original layout.
        layers = [ResidualUnit(in_channels, dilation=d) for d in (1, 3, 9)]
        layers.append(Snake1d(in_channels))
        layers.append(
            WNConv1d(
                in_channels,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                # NOTE(review): if WNConv1d follows nn.Conv1d's output-length
                # formula, stride=1 (the default) yields length L + 1 rather
                # than L. Confirm that this is intended.
                padding=math.ceil(stride / 2),
            )
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the residual stack and the strided convolution."""
        output = self.block(x)
        return output
| 62 | |
| 63 | |
| 64 | class Encoder(nn.Module): |