| 857 | |
| 858 | class ConvStacks(nn.Module): |
def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
             dropout=0, strides=None, res=True):
    """Build the input projection, a stack of conv blocks, and the output projection.

    Args:
        idim: input feature size consumed by ``in_proj`` (``Linear(idim, n_chans)``).
        n_layers: number of ``ConvBlock`` layers stacked in ``self.conv``.
        n_chans: hidden width; every ``ConvBlock`` maps ``n_chans -> n_chans``.
        odim: output feature size produced by ``out_proj`` (``Linear(n_chans, odim)``).
        kernel_size: kernel size forwarded to every ``ConvBlock``; also stored
            on ``self.kernel_size``.
        norm: normalization spec forwarded unchanged to every ``ConvBlock``
            (presumably ``'gn'`` selects group norm -- confirm in ``ConvBlock``).
        dropout: dropout value forwarded unchanged to every ``ConvBlock``.
        strides: optional iterable with one stride per layer. Defaults to a
            stride of 1 for every layer.
        res: stored on ``self.res``; presumably toggles residual connections
            in ``forward`` -- TODO confirm.

    Raises:
        ValueError: if ``strides`` is given and does not contain exactly
            ``n_layers`` entries.
    """
    super().__init__()
    # Registered first (before in_proj) so parameter / state_dict ordering
    # matches previously saved checkpoints and optimizer states.
    self.conv = torch.nn.ModuleList()
    self.kernel_size = kernel_size
    self.res = res
    self.in_proj = Linear(idim, n_chans)
    if strides is None:
        strides = [1] * n_layers
    else:
        strides = list(strides)
        # Explicit check rather than `assert`: asserts are removed under
        # `python -O`. A short list would then fail later with an IndexError,
        # and an over-long one would have its extra strides silently ignored.
        if len(strides) != n_layers:
            raise ValueError(
                f"expected {n_layers} strides (one per layer), got {len(strides)}")
    # Blocks are created in layer order after in_proj, keeping the same
    # parameter-initialization order as before.
    for stride in strides:
        self.conv.append(ConvBlock(
            n_chans, n_chans, kernel_size, stride=stride, norm=norm, dropout=dropout))
    self.out_proj = Linear(n_chans, odim)
| 874 | |
| 875 | def forward(self, x, return_hiddens=False): |
| 876 | """ |