Encoder + Decoder + loss calculation. Args: speech: (Batch, Length, ...); speech_lengths: (Batch,); text: (Batch, Length); text_lengths: (Batch,).
(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
)
| 198 | self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk") |
| 199 | |
| 200 | def forward( |
| 201 | self, |
| 202 | speech: torch.Tensor, |
| 203 | speech_lengths: torch.Tensor, |
| 204 | text: torch.Tensor, |
| 205 | text_lengths: torch.Tensor, |
| 206 | **kwargs, |
| 207 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| 208 | """Encoder + Decoder + Calc loss |
| 209 | Args: |
| 210 | speech: (Batch, Length, ...) |
| 211 | speech_lengths: (Batch, ) |
| 212 | text: (Batch, Length) |
| 213 | text_lengths: (Batch,) |
| 214 | """ |
| 215 | |
| 216 | decoding_ind = kwargs.get("decoding_ind") |
| 217 | if len(text_lengths.size()) > 1: |
| 218 | text_lengths = text_lengths[:, 0] |
| 219 | if len(speech_lengths.size()) > 1: |
| 220 | speech_lengths = speech_lengths[:, 0] |
| 221 | |
| 222 | batch_size = speech.shape[0] |
| 223 | |
| 224 | # Encoder |
| 225 | ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) |
| 226 | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) |
| 227 | |
| 228 | loss_ctc, cer_ctc = None, None |
| 229 | loss_pre = None |
| 230 | stats = dict() |
| 231 | |
| 232 | # decoder: CTC branch |
| 233 | |
| 234 | if self.ctc_weight > 0.0: |
| 235 | |
| 236 | encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk( |
| 237 | encoder_out, encoder_out_lens, chunk_outs=None |
| 238 | ) |
| 239 | |
| 240 | loss_ctc, cer_ctc = self._calc_ctc_loss( |
| 241 | encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths |
| 242 | ) |
| 243 | # Collect CTC branch stats |
| 244 | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| 245 | stats["cer_ctc"] = cer_ctc |
| 246 | |
| 247 | # decoder: Attention decoder branch |
| 248 | loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss( |
| 249 | encoder_out, encoder_out_lens, text, text_lengths |
| 250 | ) |
| 251 | |
| 252 | # 3. CTC-Att loss definition |
| 253 | if self.ctc_weight == 0.0: |
| 254 | loss = loss_att + loss_pre * self.predictor_weight |
| 255 | else: |
| 256 | loss = ( |
| 257 | self.ctc_weight * loss_ctc |
nothing calls this directly
no test coverage detected