Repository: github.com/modelscope/FunASR

Method forward

funasr/models/transducer/rnn_decoder.py, lines 220–311

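Forward pass for training. The decoder consumes the ground-truth inputs ys_in_pad one output position at a time (teacher forcing) and attends over the encoder output at every step. Args:
hs_pad: padded encoder output, shape (B, Tmax, D); in multi-encoder mode, a list with one such tensor per encoder.
hlens: encoder output lengths, shape (B); in multi-encoder mode, a list with one entry per encoder.
ys_in_pad: padded decoder input token IDs, shape (B, Lmax).
ys_in_lens: lengths of the sequences in ys_in_pad.
strm_idx: stream index; with speaker parallel attention (SPA) it selects the attention module, otherwise it is 0.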
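To make the input layout concrete, here is a minimal pure-Python sketch that right-pads a batch of token-ID sequences and records their lengths, mirroring the padded "utterances x max length" layout of ys_in_pad and the per-utterance ys_in_lens. The helper pad_batch and its pad_id default of 0 are hypothetical; the real pipeline works on torch tensors, and its padding value is not shown on this page.

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token-ID sequences to a common length (hypothetical helper)."""
    lens = [len(s) for s in seqs]
    max_len = max(lens, default=0)
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    return padded, lens


ys_in_pad, ys_in_lens = pad_batch([[5, 9, 2], [7, 3]])
olength = len(ys_in_pad[0])  # plays the role of ys_in_pad.size(1)
print(ys_in_pad)   # [[5, 9, 2], [7, 3, 0]]
print(ys_in_lens)  # [3, 2]
print(olength)     # 3
```

The decoder loop runs olength steps for every utterance, so shorter sequences are processed through their padding as well; the lengths are what let later code tell real positions from padded ones.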

(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0)
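The method body in the listing starts by wrapping single-encoder inputs in lists, converting each encoder's lengths to Python ints, and clamping strm_idx to an existing attention module. A minimal sketch of those three steps, using plain lists in place of tensors; normalize_encoder_inputs and select_att_idx are hypothetical names, not FunASR functions:

```python
from typing import Any, List, Tuple


def normalize_encoder_inputs(
    hs_pad: Any, hlens: Any, num_encs: int
) -> Tuple[List[Any], List[List[int]]]:
    """Give single- and multi-encoder inputs the same list-of-encoders layout."""
    if num_encs == 1:
        hs_pad = [hs_pad]  # one encoder -> list of length 1
        hlens = [hlens]
    # Lengths become plain ints, one list per encoder (like list(map(int, ...)))
    hlens = [[int(n) for n in hlens[idx]] for idx in range(num_encs)]
    return hs_pad, hlens


def select_att_idx(strm_idx: int, num_att_modules: int) -> int:
    """Clamp the stream index so it always names an existing attention module."""
    return min(strm_idx, num_att_modules - 1)


enc_out = [[[0.1, 0.2]], [[0.3, 0.4]]]  # B=2, T=1, D=2
hs, hl = normalize_encoder_inputs(enc_out, [1.0, 1.0], num_encs=1)
print(len(hs), hl)  # 1 [[1, 1]]
print(select_att_idx(0, 1), select_att_idx(3, 2))  # 0 1
```

Normalizing up front lets the rest of the method index hs_pad[0] and hlens[0] without branching on the input type; only the attention calls still distinguish the single- and multi-encoder paths.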

Source excerpt, lines 218–277 of the file (from the content-addressed store, hash-verified); the listing ends before the method does:

```python
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        # to support mutiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        """Forward pass for training.

        Args:
            hs_pad: TODO.
            hlens: TODO.
            ys_in_pad: TODO.
            ys_in_lens: Lengths of ys_in.
            strm_idx: TODO.
        """
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att_list) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        # get dim, length info
        olength = ys_in_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att_list[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att_list[idx](
                        hs_pad[idx],
                        hlens[idx],
```
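The excerpt follows the usual teacher-forcing pattern for attention RNN decoders: zero-initialized states for every decoder layer, embeddings computed once for the whole target, then one attention call per output position using the previous first-layer state as the query. The sketch below reproduces that control flow in pure Python under simplifications that are assumptions of this sketch, not FunASR behavior: a tanh recurrence replaces the LSTM cells (no cell state c), the embedding and attention context are summed instead of concatenated, there is no dropout, and attention is plain dot-product. init_zero_state, attend, and teacher_forced_forward are hypothetical names.

```python
import math
from typing import Dict, List

Vec = List[float]


def init_zero_state(batch: int, dim: int) -> List[Vec]:
    """Zero hidden state for one decoder layer: one vector per utterance."""
    return [[0.0] * dim for _ in range(batch)]


def attend(query: Vec, frames: List[Vec], length: int) -> Vec:
    """Softmax dot-product attention over the first `length` (unpadded) frames."""
    valid = frames[:length]  # padded frames are never scored
    scores = [sum(q * h for q, h in zip(query, f)) for f in valid]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * f[d] for w, f in zip(weights, valid)) for d in range(len(query))]


def teacher_forced_forward(
    hs_pad: List[List[Vec]],
    hlens: List[int],
    ys_in_pad: List[List[int]],
    embed: Dict[int, Vec],
    dlayers: int = 2,
) -> List[List[Vec]]:
    """Run a toy decoder over the ground-truth inputs, one output step at a time."""
    batch, dim = len(hs_pad), len(hs_pad[0][0])
    olength = len(ys_in_pad[0])  # analogue of ys_in_pad.size(1)
    # One zero state per layer, as in the excerpt's initialization loop
    z_list = [init_zero_state(batch, dim) for _ in range(dlayers)]
    # Embed the whole target once, before the loop (B x olen x dim)
    eys = [[embed[t] for t in seq] for seq in ys_in_pad]
    z_all = []
    for i in range(olength):
        # The previous step's first-layer state is the attention query
        att_c = [attend(z_list[0][b], hs_pad[b], hlens[b]) for b in range(batch)]
        # Layer-0 input: ground-truth embedding at step i plus context (summed here)
        x = [[e + c for e, c in zip(eys[b][i], att_c[b])] for b in range(batch)]
        for layer in range(dlayers):
            z_list[layer] = [
                [math.tanh(xi + zi) for xi, zi in zip(x[b], z_list[layer][b])]
                for b in range(batch)
            ]
            x = z_list[layer]  # the next layer consumes this layer's new state
        z_all.append(z_list[-1])  # top-layer output for step i
    return z_all  # olen x B x dim


hs_pad = [
    [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],  # utterance 0: 2 real frames + 1 pad
    [[0.5, 0.5], [0.0, 0.0], [0.0, 0.0]],  # utterance 1: 1 real frame + 2 pad
]
hlens = [2, 1]
embed = {0: [0.0, 0.0], 1: [0.1, -0.1], 2: [-0.2, 0.3]}
ys_in_pad = [[1, 2], [2, 0]]
z_all = teacher_forced_forward(hs_pad, hlens, ys_in_pad, embed)
print(len(z_all), len(z_all[0]), len(z_all[0][0]))  # 2 2 2
```

Because attend only reads the first hlens[b] frames, changing the padded frames leaves the output unchanged; the excerpt likewise passes hlens[0] to the attention module alongside hs_pad[0]. Feeding ground-truth embeddings rather than the model's own predictions is what allows all embeddings to be computed before the loop.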

Callers

nothing calls this directly

Calls (7)

zero_state · method · 0.95
rnn_forward · method · 0.95
to_device · function · 0.90
make_pad_mask · function · 0.90
reset · method · 0.45
output · method · 0.45
argmax · method · 0.45
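make_pad_mask appears among the calls. In ESPnet-derived code it turns a batch of lengths into a boolean mask that is True at padded positions. Assuming FunASR's version keeps that convention, a pure-Python sketch looks like this (make_pad_mask_sketch is a hypothetical name; the real function works on tensors and may take more arguments):

```python
from typing import List, Optional


def make_pad_mask_sketch(lengths: List[int], maxlen: Optional[int] = None) -> List[List[bool]]:
    """Return a B x maxlen mask that is True where position t >= the sequence length."""
    if maxlen is None:
        maxlen = max(lengths)
    return [[t >= n for t in range(maxlen)] for n in lengths]


for row in make_pad_mask_sketch([5, 3, 2]):
    print([int(v) for v in row])
# [0, 0, 0, 0, 0]
# [0, 0, 0, 1, 1]
# [0, 0, 1, 1, 1]
```

With this convention the mask selects exactly the positions to ignore, for example when excluding padded decoder steps from a loss or accuracy computation.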

Tested by

no test coverage detected