hub / github.com/modelscope/FunASR / forward

Method forward

funasr/models/scama/model.py:200–275

Encoder + Decoder + Calc loss

Args:
    speech: (Batch, Length, ...)
    speech_lengths: (Batch,)
    text: (Batch, Length)
    text_lengths: (Batch,)

(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
)
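
The argument shapes can be made concrete with a small sketch. It only illustrates the length handling that the method applies to its inputs, where a 2-D lengths tensor is reduced to its first column. The helper name `squeeze_lengths`, the feature dimension of 80, and the vocabulary size of 100 are assumptions chosen for illustration and are not part of FunASR.

```python
import torch


def squeeze_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mirror the length handling done in forward()."""
    # A (Batch, K) tensor is reduced to its first column, giving (Batch,).
    if lengths.dim() > 1:
        lengths = lengths[:, 0]
    return lengths


# Dummy batch: 2 utterances, up to 50 frames, 80-dim features (assumed sizes).
speech = torch.randn(2, 50, 80)               # (Batch, Length, ...)
speech_lengths = torch.tensor([[50], [42]])   # (Batch, 1), as some loaders produce
text = torch.randint(0, 100, (2, 7))          # (Batch, Length) token ids
text_lengths = torch.tensor([7, 5])           # already (Batch,)

speech_lengths = squeeze_lengths(speech_lengths)
text_lengths = squeeze_lengths(text_lengths)
batch_size = speech.shape[0]
```

After this step both length tensors are 1-D with one entry per utterance, which is the form the docstring describes.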

Source from the content-addressed store, hash-verified

        self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """

        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch

        if self.ctc_weight > 0.0:

            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = (
                self.ctc_weight * loss_ctc
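
The listing stops partway through the `else` branch of the loss definition. As a rough sketch of how the terms could combine, the helper below reproduces the visible `ctc_weight == 0.0` branch exactly and, for the other branch, assumes the conventional hybrid CTC/attention interpolation. That completion is an assumption and is not taken from the source; `combine_losses` is a hypothetical name, and plain floats stand in for loss tensors.

```python
def combine_losses(loss_att, loss_pre, loss_ctc, ctc_weight, predictor_weight):
    """Hypothetical helper sketching the loss weighting in forward()."""
    if ctc_weight == 0.0:
        # Matches the visible branch: attention loss plus weighted predictor loss.
        return loss_att + loss_pre * predictor_weight
    # ASSUMPTION: the truncated branch interpolates CTC and attention losses
    # with weights ctc_weight and (1 - ctc_weight), then adds the predictor term.
    return (
        ctc_weight * loss_ctc
        + (1.0 - ctc_weight) * loss_att
        + loss_pre * predictor_weight
    )


# Attention-only training: the CTC loss is unused and may be None.
loss_no_ctc = combine_losses(2.0, 0.5, None, ctc_weight=0.0, predictor_weight=1.0)

# Joint training with equal CTC and attention weights.
loss_joint = combine_losses(2.0, 0.5, 4.0, ctc_weight=0.5, predictor_weight=1.0)
```

With `ctc_weight` at 0.0 the CTC branch is skipped entirely in the listing, so `loss_ctc` stays `None` and must not enter the sum, which is why the two cases are handled separately.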

Callers

nothing calls this directly

Calls 6

encode · Method · 0.95
force_gatherable · Function · 0.90
random_choice · Method · 0.80
remove_chunk · Method · 0.80
_calc_ctc_loss · Method · 0.45

Tested by

no test coverage detected