MCPcopy Index your code
hub / github.com/modelscope/FunASR / recognize_beam

Method recognize_beam

funasr/models/language_model/rnn/decoders.py:342–631  ·  view source on GitHub ↗

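beam search implementation :param torch.Tensor h: encoder hidden state (T, eprojs) [in multi-encoder case, list of torch.Tensor, [(T1, eprojs), (T2, eprojs), ...] ] :param torch.Tensor lpz: ctc log softmax output (T, odim) [in multi-encoder case, list of torch.Tensor, [(T1, odim), (T2, odim), ...] ] :param Namespace recog_args: argument Namespace containing options :param char_list: list of character strings :param torch.nn.Module rnnlm: language module :param int strm_idx: stream index for speaker parallel attention in multi-speaker case :return: N-best decoding results :rtype: list of dicts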

(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0)

Source from the content-addressed store, hash-verified

340 return self.loss, acc, ppl
341
342 def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
343 """beam search implementation
344
345 :param torch.Tensor h: encoder hidden state (T, eprojs)
346 [in multi-encoder case, list of torch.Tensor,
347 [(T1, eprojs), (T2, eprojs), ...] ]
348 :param torch.Tensor lpz: ctc log softmax output (T, odim)
349 [in multi-encoder case, list of torch.Tensor,
350 [(T1, odim), (T2, odim), ...] ]
351 :param Namespace recog_args: argument Namespace containing options
352 :param char_list: list of character strings
353 :param torch.nn.Module rnnlm: language module
354 :param int strm_idx:
355 stream index for speaker parallel attention in multi-speaker case
356 :return: N-best decoding results
357 :rtype: list of dicts
358 """
359 # to support mutiple encoder asr mode, in single encoder mode,
360 # convert torch.Tensor to List of torch.Tensor
361 if self.num_encs == 1:
362 h = [h]
363 lpz = [lpz]
364 if self.num_encs > 1 and lpz is None:
365 lpz = [lpz] * self.num_encs
366
367 for idx in range(self.num_encs):
368 logging.info(
369 "Number of Encoder:{}; enc{}: input lengths: {}.".format(
370 self.num_encs, idx + 1, h[0].size(0)
371 )
372 )
373 att_idx = min(strm_idx, len(self.att) - 1)
374 # initialization
375 c_list = [self.zero_state(h[0].unsqueeze(0))]
376 z_list = [self.zero_state(h[0].unsqueeze(0))]
377 for _ in six.moves.range(1, self.dlayers):
378 c_list.append(self.zero_state(h[0].unsqueeze(0)))
379 z_list.append(self.zero_state(h[0].unsqueeze(0)))
380 if self.num_encs == 1:
381 a = None
382 self.att[att_idx].reset() # reset pre-computation of h
383 else:
384 a = [None] * (self.num_encs + 1) # atts + han
385 att_w_list = [None] * (self.num_encs + 1) # atts + han
386 att_c_list = [None] * (self.num_encs) # atts
387 for idx in range(self.num_encs + 1):
388 self.att[idx].reset() # reset pre-computation of h in atts and han
389
390 # search parms
391 beam = recog_args.beam_size
392 penalty = recog_args.penalty
393 ctc_weight = getattr(recog_args, "ctc_weight", False) # for NMT
394
395 if lpz[0] is not None and self.num_encs > 1:
396 # weights-ctc,
397 # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
398 weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
399 recog_args.weights_ctc_dec

Callers

nothing calls this directly

Calls 8

zero_state · Method · 0.95
rnn_forward · Method · 0.95
CTCPrefixScore · Class · 0.90
end_detect · Function · 0.90
initial_state · Method · 0.80
reset · Method · 0.45
output · Method · 0.45
log_softmax · Method · 0.45

Tested by

no test coverage detected