hub / github.com/modelscope/FunASR / forward

Method forward

funasr/models/xvector/e2e_sv.py:92–210

Frontend + Encoder + Decoder + Calc loss

Args:
    speech: (Batch, Length, ...)
    speech_lengths: (Batch, )
    text: (Batch, Length)
    text_lengths: (Batch,)

(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    )
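
The shape contract in the docstring, together with the first checks in the method body, can be sketched without PyTorch. The snippet below is a minimal illustration, not FunASR code: nested Python lists stand in for tensors, and `check_batch` and `trim_to_longest` are hypothetical helper names. It mirrors the batch-size assertion and the `text[:, : text_lengths.max()]` truncation. The source labels that truncation "for data-parallel"; a plausible reason is that a shard of a padded batch can carry more padding than its own longest target needs.

```python
from typing import List, Sequence


def check_batch(speech: Sequence, speech_lengths: Sequence[int],
                text: Sequence, text_lengths: Sequence[int]) -> int:
    """Return the shared batch size, or raise if the four inputs disagree."""
    sizes = (len(speech), len(speech_lengths), len(text), len(text_lengths))
    if len(set(sizes)) != 1:
        raise ValueError(f"batch sizes differ: {sizes}")
    return sizes[0]


def trim_to_longest(text: List[List[int]], text_lengths: List[int]) -> List[List[int]]:
    """Drop padding columns beyond the longest target in the batch."""
    longest = max(text_lengths)
    return [row[:longest] for row in text]


# Two utterances; targets are padded with -1 to a width of 5 (illustrative values).
speech = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.0]]
speech_lengths = [3, 2]
text = [[7, 8, 9, -1, -1], [4, 5, -1, -1, -1]]
text_lengths = [3, 2]

batch_size = check_batch(speech, speech_lengths, text, text_lengths)
trimmed = trim_to_longest(text, text_lengths)
print(batch_size, trimmed)  # 2 [[7, 8, 9], [4, 5, -1]]
```

With real tensors the same two steps are the `assert` on `shape[0]` and the slice on the second dimension shown in the source.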

Source from the content-addressed store, hash-verified

        self.decoder = decoder

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
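
The encoder-output handling and the intermediate-CTC loop in the excerpt can be sketched in isolation. This is an illustration under stated assumptions, not the FunASR implementation: `fake_ctc_loss` is a hypothetical stand-in for `self._calc_ctc_loss`, plain floats and lists stand in for tensors, and only the accumulation visible in the excerpt is reproduced. What the method does with the summed loss afterwards is not shown in the excerpt, so it is left out here.

```python
from typing import Any, List, Optional, Tuple


def split_encoder_out(encoder_out: Any) -> Tuple[Any, Optional[List[Tuple[int, Any]]]]:
    """Separate the final encoder output from optional intermediate outputs."""
    if isinstance(encoder_out, tuple):
        return encoder_out[0], encoder_out[1]
    return encoder_out, None


def fake_ctc_loss(hidden: List[float]) -> Tuple[float, Optional[float]]:
    """Hypothetical stand-in for _calc_ctc_loss: returns (loss, cer)."""
    return sum(hidden) / len(hidden), None


def accumulate_interctc(intermediate_outs, interctc_weight: float) -> float:
    """Sum the per-layer losses, as the loop in the excerpt does."""
    loss_interctc = 0.0
    if interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            loss_ic, _cer_ic = fake_ctc_loss(intermediate_out)
            loss_interctc = loss_interctc + loss_ic
    return loss_interctc


# An encoder that also returns outputs tagged with layer indices 3 and 6.
final, inter = split_encoder_out(([1.0, 3.0], [(3, [2.0, 4.0]), (6, [1.0, 1.0])]))
total = accumulate_interctc(inter, interctc_weight=0.3)
print(final, total)  # [1.0, 3.0] 4.0
```

Setting `interctc_weight` to `0.0`, or passing an encoder output that is not a tuple, skips the loop and leaves the accumulated loss at `0.0`, which matches the guard in the source.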

Callers

nothing calls this directly

Calls 5

encode · Method · 0.95
force_gatherable · Function · 0.90
_calc_transducer_loss · Method · 0.80
_calc_ctc_loss · Method · 0.45
_calc_att_loss · Method · 0.45

Tested by

no test coverage detected