(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
)
| 63 | |
class Encoder(nn.Module):
    """Convolutional encoder: input conv, a stack of EncoderBlocks, then a latent projection.

    Layer order (all wrapped in ``self.block``):
      1. ``WNConv1d(1, d_model, kernel_size=7, padding=3)``
      2. one ``EncoderBlock`` per entry in ``strides``. Before each block the
         channel count is doubled, and the block receives that doubled width
         and the stride.
      3. ``Snake1d`` on the final hidden width.
      4. ``WNConv1d(hidden, d_latent, kernel_size=3, padding=1)``

    Args:
        d_model: Channel count produced by the first convolution. It is
            doubled once per entry in ``strides``.
        strides: Stride for each EncoderBlock, applied in order. Any iterable
            of ints is accepted. The default is an immutable tuple, so the
            default object cannot be mutated and shared across instances.
        d_latent: Number of output channels of the final convolution.

    Attributes:
        block: ``nn.Sequential`` holding every layer in execution order.
        enc_dim: Channel count after the last EncoderBlock, i.e. the initial
            ``d_model`` doubled once per stride. This is the input width of
            the final projection, not ``d_latent``.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: tuple = (2, 4, 8, 8),
        d_latent: int = 64,
    ):
        super().__init__()
        # First convolution lifts the single input channel to d_model channels.
        # (kernel_size=7 with padding=3 presumably preserves the time length,
        # assuming WNConv1d wraps a stride-1 Conv1d -- TODO confirm.)
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # EncoderBlocks double the channel count as they downsample by `stride`.
        for stride in strides:
            d_model *= 2
            layers.append(EncoderBlock(d_model, stride=stride))

        # Final activation and projection from the widest hidden width to d_latent.
        layers += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap the layer list into a single nn.Sequential (submodules are
        # registered as block.0, block.1, ... in the order built above).
        self.block = nn.Sequential(*layers)
        # Hidden width after the last EncoderBlock (input width of the final conv).
        self.enc_dim = d_model

    def forward(self, x):
        """Run ``x`` through every layer of ``self.block`` and return the result.

        Presumably ``x`` is shaped (batch, 1, time), since the first conv has a
        single input channel -- TODO confirm against callers.
        """
        return self.block(x)
Nothing calls this directly.
No test coverage detected.