Repository: github.com/modelscope/FunASR

Method forward

funasr/models/branchformer/encoder.py:146–287

Compute encoded features. Takes the input tensor (optionally paired with a positional embedding) and its mask. The `cache` argument must be None; any other value raises `NotImplementedError`.
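The input convention (a bare tensor, or a tuple of tensor and positional embedding) and the matching output form can be sketched as follows. This is a minimal pure-Python illustration, not FunASR code: `split_input` and `repack` are hypothetical helper names, and plain objects stand in for `torch.Tensor` values so the snippet runs without PyTorch.

```python
from typing import Any, Optional, Tuple, Union

# Hypothetical helpers mirroring how forward() treats x_input.
# Plain Python objects stand in for torch.Tensor values.


def split_input(x_input: Union[Tuple[Any, Any], Any]) -> Tuple[Any, Optional[Any]]:
    """Return (x, pos_emb); pos_emb is None when no positional embedding is given."""
    if isinstance(x_input, tuple):
        return x_input[0], x_input[1]
    return x_input, None


def repack(x: Any, pos_emb: Optional[Any], mask: Any):
    """Return the output in the same form the input arrived in, plus the mask."""
    if pos_emb is not None:
        return (x, pos_emb), mask
    return x, mask


if __name__ == "__main__":
    print(split_input(("feats", "pos")))   # ('feats', 'pos')
    print(split_input("feats"))            # ('feats', None)
    print(repack("feats", "pos", "mask"))  # (('feats', 'pos'), 'mask')
```

Note that the check is `isinstance(x_input, tuple)`: although the docstring writes the pair in list notation, a list would be treated as a plain tensor input.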

(self, x_input, mask, cache=None)


```python
def forward(self, x_input, mask, cache=None):
    """Compute encoded features.

    Args:
        x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
            - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
            - w/o pos emb: Tensor (#batch, time, size).
        mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
        cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

    Returns:
        torch.Tensor: Output tensor (#batch, time, size).
        torch.Tensor: Mask tensor (#batch, time).
    """

    if cache is not None:
        raise NotImplementedError("cache is not None, which is not tested")

    if isinstance(x_input, tuple):
        x, pos_emb = x_input[0], x_input[1]
    else:
        x, pos_emb = x_input, None

    skip_layer = False
    # with stochastic depth, residual connection `x + f(x)` becomes
    # `x <- x + 1 / (1 - p) * f(x)` at training time.
    stoch_layer_coeff = 1.0
    if self.training and self.stochastic_depth_rate > 0:
        skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
        stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

    if skip_layer:
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask

    # Two branches
    x1 = x
    x2 = x

    # Branch 1: multi-headed attention module
    if self.attn is not None:
        x1 = self.norm_mha(x1)

        if isinstance(self.attn, FastSelfAttention):
            x_att = self.attn(x1, mask)
        else:
            if pos_emb is not None:
                x_att = self.attn(x1, x1, x1, pos_emb, mask)
            else:
                x_att = self.attn(x1, x1, x1, mask)

        x1 = self.dropout(x_att)

    # Branch 2: convolutional gating mlp
    if self.cgmlp is not None:
```
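The stochastic-depth comment in the listing says that during training the layer is skipped with probability p and, when kept, its residual branch is scaled by 1 / (1 - p), so the expected output matches the evaluation-time residual x + f(x). The sketch below checks that numerically on scalars. It assumes the coefficient is applied to the residual exactly as the comment describes (the combining code falls outside the lines shown); `stochastic_residual` is a hypothetical name, and Python's `random` replaces `torch.rand`.

```python
import random


def stochastic_residual(x: float, fx: float, p: float, training: bool,
                        rng: random.Random) -> float:
    """Residual update with stochastic depth, on scalars instead of tensors.

    Training with p > 0: skip the layer (return x unchanged) with
    probability p, otherwise return x + fx / (1 - p).
    Evaluation, or p == 0: plain residual x + fx.
    """
    coeff = 1.0
    if training and p > 0:
        if rng.random() < p:  # mirrors torch.rand(1).item() < p
            return x          # layer skipped: identity mapping
        coeff = 1.0 / (1.0 - p)
    return x + coeff * fx


if __name__ == "__main__":
    rng = random.Random(0)
    x, fx, p = 1.0, 3.0, 0.5
    n = 100_000
    mean = sum(stochastic_residual(x, fx, p, True, rng) for _ in range(n)) / n
    print(f"training mean ~ {mean:.3f}")                       # close to 4.0
    print("eval:", stochastic_residual(x, fx, p, False, rng))  # eval: 4.0
```

Deciding whether to skip before either branch runs, as `forward` does, means a skipped layer costs no attention or convolution compute at all.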
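Branch 1 calls the attention module with one of three argument lists, depending on whether it is a `FastSelfAttention` instance and whether a positional embedding was supplied. The classes below are hypothetical stubs (they only report how they were called), so the dispatch can be run and checked without FunASR or PyTorch; the stub named `FastSelfAttention` is not the real class.

```python
from typing import Any, Optional, Tuple


class FastSelfAttention:
    """Stub standing in for the real class: takes (x, mask)."""

    def __call__(self, x: Any, mask: Any) -> Tuple:
        return ("fast", x, mask)


class GenericAttention:
    """Stub: takes (query, key, value[, pos_emb], mask) and reports the arg count."""

    def __call__(self, *args: Any) -> Tuple:
        return ("generic", len(args))


def call_attention(attn: Any, x1: Any, mask: Any, pos_emb: Optional[Any]) -> Tuple:
    # Same three-way dispatch as Branch 1 of forward().
    if isinstance(attn, FastSelfAttention):
        return attn(x1, mask)  # pos_emb is not passed on this path
    if pos_emb is not None:
        return attn(x1, x1, x1, pos_emb, mask)  # query, key, value, pos_emb, mask
    return attn(x1, x1, x1, mask)  # query, key, value, mask


if __name__ == "__main__":
    print(call_attention(FastSelfAttention(), "x", "m", None))  # ('fast', 'x', 'm')
    print(call_attention(GenericAttention(), "x", "m", "pos"))  # ('generic', 5)
    print(call_attention(GenericAttention(), "x", "m", None))   # ('generic', 4)
```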
