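Frontend + Encoder. Note that this method is used by asr_inference.py. Args: speech: (Batch, Length, ...); speech_lengths: (Batch,); **kwargs: extra keyword arguments (not used by this method). Returns: (encoder_out, encoder_out_lens), where encoder_out becomes the tuple (encoder_out, intermediate_outs) when the encoder also returns intermediate outputs.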
(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
)
| 300 | return loss, stats, weight |
| 301 | |
| 302 | def encode( |
| 303 | self, |
| 304 | speech: torch.Tensor, |
| 305 | speech_lengths: torch.Tensor, |
| 306 | **kwargs, |
| 307 | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 308 | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| 309 | Args: |
| 310 | speech: (Batch, Length, ...) |
| 311 | speech_lengths: (Batch, ) |
| 312 |         **kwargs: extra keyword arguments (not used by this method) |
| 313 | """ |
| 314 | with autocast(False): |
| 315 | # Data augmentation |
| 316 | if self.specaug is not None and self.training: |
| 317 | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| 318 | # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| 319 | if self.normalize is not None: |
| 320 | speech, speech_lengths = self.normalize(speech, speech_lengths) |
| 321 | # Forward encoder |
| 322 | # feats: (Batch, Length, Dim) |
| 323 | # -> encoder_out: (Batch, Length2, Dim2) |
| 324 | if self.encoder.interctc_use_conditioning: |
| 325 | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc) |
| 326 | else: |
| 327 | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) |
| 328 | intermediate_outs = None |
| 329 | if isinstance(encoder_out, tuple): |
| 330 | intermediate_outs = encoder_out[1] |
| 331 | encoder_out = encoder_out[0] |
| 332 | |
| 333 | if intermediate_outs is not None: |
| 334 | return (encoder_out, intermediate_outs), encoder_out_lens |
| 335 | return encoder_out, encoder_out_lens |
| 336 | |
| 337 | def _calc_att_loss( |
| 338 | self, |