:param x: [B, T, H] :return: [B, T, H]
(self, x, return_hiddens=False)
| 873 | self.out_proj = Linear(n_chans, odim) |
| 874 | |
def forward(self, x, return_hiddens=False):
    """Apply the input projection, the conv stack, then the output projection.

    :param x: input tensor, documented upstream as [B, T, H]; it is fed to
        ``self.in_proj`` and then transposed so the conv layers see it
        channels-first.
    :param return_hiddens: if True, also return the output of every conv
        layer (after the optional residual add).
    :return: the ``self.out_proj`` output, [B, T, odim] when ``out_proj`` is
        built as ``Linear(n_chans, odim)``. If ``return_hiddens`` is True,
        returns a tuple ``(x, hiddens)``. ``hiddens`` stacks each layer's
        (B, C, T) output along dim 1, giving [B, L, C, T], where L is the
        number of conv layers.
    """
    x = self.in_proj(x)
    x = x.transpose(1, -1)  # channels-first for the conv layers: (B, C, Tmax)
    # Only collect per-layer outputs when the caller asks for them, so we do
    # not hold a reference to every intermediate tensor otherwise.
    hiddens = [] if return_hiddens else None
    for layer in self.conv:
        out = layer(x)
        # Residual connection when enabled; otherwise the layer output replaces x.
        x = (x + out) if self.res else out  # (B, C, Tmax)
        if hiddens is not None:
            hiddens.append(x)
    x = x.transpose(1, -1)  # back to (B, Tmax, C)
    x = self.out_proj(x)  # (B, Tmax, odim)
    if return_hiddens:
        return x, torch.stack(hiddens, 1)  # [B, L, C, T]
    return x
| 894 | |
| 895 | |
| 896 | class ConvGlobalStacks(nn.Module): |