Compute encoded features. Args: x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. - w/o pos emb: Tensor (#batch, time, size). mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
(self, x_input, mask, cache=None)
| 144 | self.merge_proj = torch.nn.Identity() |
| 145 | |
| 146 | def forward(self, x_input, mask, cache=None): |
| 147 | """Compute encoded features. |
| 148 | |
| 149 | Args: |
| 150 | x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. |
| 151 | - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. |
| 152 | - w/o pos emb: Tensor (#batch, time, size). |
| 153 | mask (torch.Tensor): Mask tensor for the input (#batch, 1, time). |
| 154 | cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). |
| 155 | |
| 156 | Returns: |
| 157 | torch.Tensor: Output tensor (#batch, time, size). |
| 158 | torch.Tensor: Mask tensor (#batch, time). |
| 159 | """ |
| 160 | |
| 161 | if cache is not None: |
| 162 | raise NotImplementedError("cache is not None, which is not tested") |
| 163 | |
| 164 | if isinstance(x_input, tuple): |
| 165 | x, pos_emb = x_input[0], x_input[1] |
| 166 | else: |
| 167 | x, pos_emb = x_input, None |
| 168 | |
| 169 | skip_layer = False |
| 170 | # with stochastic depth, residual connection `x + f(x)` becomes |
| 171 | # `x <- x + 1 / (1 - p) * f(x)` at training time. |
| 172 | stoch_layer_coeff = 1.0 |
| 173 | if self.training and self.stochastic_depth_rate > 0: |
| 174 | skip_layer = torch.rand(1).item() < self.stochastic_depth_rate |
| 175 | stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) |
| 176 | |
| 177 | if skip_layer: |
| 178 | if cache is not None: |
| 179 | x = torch.cat([cache, x], dim=1) |
| 180 | if pos_emb is not None: |
| 181 | return (x, pos_emb), mask |
| 182 | return x, mask |
| 183 | |
| 184 | # Two branches |
| 185 | x1 = x |
| 186 | x2 = x |
| 187 | |
| 188 | # Branch 1: multi-headed attention module |
| 189 | if self.attn is not None: |
| 190 | x1 = self.norm_mha(x1) |
| 191 | |
| 192 | if isinstance(self.attn, FastSelfAttention): |
| 193 | x_att = self.attn(x1, mask) |
| 194 | else: |
| 195 | if pos_emb is not None: |
| 196 | x_att = self.attn(x1, x1, x1, pos_emb, mask) |
| 197 | else: |
| 198 | x_att = self.attn(x1, x1, x1, mask) |
| 199 | |
| 200 | x1 = self.dropout(x_att) |
| 201 | |
| 202 | # Branch 2: convolutional gating mlp |
| 203 | if self.cgmlp is not None: |