Repository: github.com/modelscope/FunASR

Method forward

funasr/models/language_model/rnn/decoders.py:177–340

Decoder forward pass. hs_pad is a batch of padded hidden-state sequences of shape (B, Tmax, D); in the multi-encoder case it is a list of tensors [(B, Tmax_1, D), (B, Tmax_2, D), ...].

forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None)
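forward first normalizes its arguments so that single- and multi-encoder calls share one code path: a single encoder's hs_pad and hlens are wrapped in one-element lists, each encoder's lengths become plain Python ints, and strm_idx is clamped to a valid attention-module index (used by speaker parallel attention). A plain-Python sketch of those statements, using strings and lists in place of tensors; normalize_encoder_inputs and select_attention_index are hypothetical names, not FunASR functions:

```python
from typing import Any, List, Tuple


def normalize_encoder_inputs(
    hs_pad: Any, hlens: Any, num_encs: int
) -> Tuple[List[Any], List[List[int]]]:
    """Mirror the single-/multi-encoder handling at the top of forward (sketch)."""
    if num_encs == 1:
        # single-encoder mode: wrap so later code can always index by encoder
        hs_pad = [hs_pad]
        hlens = [hlens]
    # one list of plain Python ints per encoder
    hlens = [list(map(int, hlens[idx])) for idx in range(num_encs)]
    return hs_pad, hlens


def select_attention_index(strm_idx: int, num_att_modules: int) -> int:
    """Clamp the stream index to the last available attention module (sketch)."""
    return min(strm_idx, num_att_modules - 1)


hs, lens = normalize_encoder_inputs("enc-out", [3.0, 2.0], num_encs=1)
print(hs, lens)                      # ['enc-out'] [[3, 2]]
print(select_attention_index(1, 1))  # 0
```

With a single attention module every stream maps to index 0, which matches the source comment that att_idx is 0 outside the SPA case.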


```python
        # Excerpt from a decoder class; relies on module-level logging, torch and pad_list.
        return z_list, c_list  # end of the preceding method

    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            [in multi-encoder case,
            list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ..., ] ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
            (B, Lmax)
        :param int strm_idx: stream index indicating the index of the decoding stream
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        """
        # to support multiple-encoder ASR mode; in single-encoder mode,
        # convert torch.Tensor to a list of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select the attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlens should be a list of lists of integers
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # pad ys_in with eos and ys_out with ignore_id (-1)
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            logging.info(
                self.__class__.__name__
                + "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, hlens[idx]
                )
            )
        # ... (the rest of the method, through source line 340, is not shown)
```
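The target-preparation steps in forward (strip ignore_id padding, prepend sos or a language id, append eos, then pad ys_in with eos and ys_out with ignore_id) can be sketched in plain Python on nested lists. This is an illustration, not FunASR code: pad_lists and prepare_targets are hypothetical names, ignore_id = -1 and the token ids are assumptions, and lang_ids is taken as a flat list of ints rather than a (B, 1) tensor.

```python
from typing import List, Optional, Sequence, Tuple


def pad_lists(seqs: Sequence[Sequence[int]], pad_value: int) -> List[List[int]]:
    """Right-pad each sequence to the longest length (list analogue of pad_list)."""
    max_len = max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]


def prepare_targets(
    ys_pad: Sequence[Sequence[int]],
    sos: int,
    eos: int,
    ignore_id: int = -1,
    replace_sos: bool = False,
    lang_ids: Optional[Sequence[int]] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Build padded decoder inputs and targets from padded label sequences (sketch)."""
    # 1. drop padding so each utterance keeps only its real token ids
    ys = [[t for t in y if t != ignore_id] for y in ys_pad]
    # 2. decoder input: <sos> (or the language id when replace_sos is set) + tokens
    if replace_sos:
        ys_in = [[lid] + y for lid, y in zip(lang_ids, ys)]
    else:
        ys_in = [[sos] + y for y in ys]
    # 3. decoder target: tokens + <eos>
    ys_out = [y + [eos] for y in ys]
    # 4. pad inputs with eos and targets with ignore_id, as forward does
    return pad_lists(ys_in, eos), pad_lists(ys_out, ignore_id)


ys_in_pad, ys_out_pad = prepare_targets([[3, 4, 5], [6, -1, -1]], sos=1, eos=2)
print(ys_in_pad)   # [[1, 3, 4, 5], [1, 6, 2, 2]]
print(ys_out_pad)  # [[3, 4, 5, 2], [6, 2, -1, -1]]
```

Because ys_in and ys_out have the same length per utterance, the eos padding in ys_in lines up with ignore_id positions in ys_out, so those steps are harmless, assuming the loss skips ignore_id targets.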

Callers

nothing calls this directly

Calls (9)

zero_state · Method · 0.95
rnn_forward · Method · 0.95
pad_list · Function · 0.90
to_device · Function · 0.90
th_accuracy · Function · 0.90
reset · Method · 0.45
output · Method · 0.45
argmax · Method · 0.45
log_softmax · Method · 0.45
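The docstring says forward also returns an accuracy, and th_accuracy appears among its callees. A common definition, assumed here rather than taken from FunASR, is token-level accuracy over positions whose target is not ignore_id, with each prediction being the argmax of a score row. A plain-Python sketch; argmax and token_accuracy are hypothetical helpers, and ignore_id = -1 is an assumption:

```python
from typing import Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score (the first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def token_accuracy(
    scores: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[Sequence[int]],
    ignore_id: int = -1,
) -> float:
    """Fraction of non-ignored target positions predicted correctly (sketch)."""
    correct = total = 0
    for utt_scores, utt_targets in zip(scores, targets):
        for row, target in zip(utt_scores, utt_targets):
            if target == ignore_id:
                continue  # padded position: excluded from numerator and denominator
            total += 1
            correct += int(argmax(row) == target)
    return correct / total if total else 0.0


scores = [[[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]], [[0.2, 0.2, 0.6], [0.3, 0.3, 0.4]]]
targets = [[1, 2], [2, -1]]
print(token_accuracy(scores, targets))  # 0.6666666666666666 (2 of 3 scored positions)
```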

Tested by

no test coverage detected