beam search implementation :param torch.Tensor h: encoder hidden state (T, eprojs) [in multi-encoder case, list of torch.Tensor, [(T1, eprojs), (T2, eprojs), ...] ] :param torch.Tensor lpz: ctc log softmax output (T, od
(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0)
| 340 | return self.loss, acc, ppl |
| 341 | |
| 342 | def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0): |
| 343 | """beam search implementation |
| 344 | |
| 345 | :param torch.Tensor h: encoder hidden state (T, eprojs) |
| 346 | [in multi-encoder case, list of torch.Tensor, |
| 347 | [(T1, eprojs), (T2, eprojs), ...] ] |
| 348 | :param torch.Tensor lpz: ctc log softmax output (T, odim) |
| 349 | [in multi-encoder case, list of torch.Tensor, |
| 350 | [(T1, odim), (T2, odim), ...] ] |
| 351 | :param Namespace recog_args: argument Namespace containing options |
| 352 | :param char_list: list of character strings |
| 353 | :param torch.nn.Module rnnlm: language module |
| 354 | :param int strm_idx: |
| 355 | stream index for speaker parallel attention in multi-speaker case |
| 356 | :return: N-best decoding results |
| 357 | :rtype: list of dicts |
| 358 | """ |
| 359 | # to support multiple encoder asr mode, in single encoder mode, |
| 360 | # convert torch.Tensor to List of torch.Tensor |
| 361 | if self.num_encs == 1: |
| 362 | h = [h] |
| 363 | lpz = [lpz] |
| 364 | if self.num_encs > 1 and lpz is None: |
| 365 | lpz = [lpz] * self.num_encs |
| 366 | |
| 367 | for idx in range(self.num_encs): |
| 368 | logging.info( |
| 369 | "Number of Encoder:{}; enc{}: input lengths: {}.".format( |
| 370 | self.num_encs, idx + 1, h[0].size(0) |
| 371 | ) |
| 372 | ) |
| 373 | att_idx = min(strm_idx, len(self.att) - 1) |
| 374 | # initialization |
| 375 | c_list = [self.zero_state(h[0].unsqueeze(0))] |
| 376 | z_list = [self.zero_state(h[0].unsqueeze(0))] |
| 377 | for _ in six.moves.range(1, self.dlayers): |
| 378 | c_list.append(self.zero_state(h[0].unsqueeze(0))) |
| 379 | z_list.append(self.zero_state(h[0].unsqueeze(0))) |
| 380 | if self.num_encs == 1: |
| 381 | a = None |
| 382 | self.att[att_idx].reset() # reset pre-computation of h |
| 383 | else: |
| 384 | a = [None] * (self.num_encs + 1) # atts + han |
| 385 | att_w_list = [None] * (self.num_encs + 1) # atts + han |
| 386 | att_c_list = [None] * (self.num_encs) # atts |
| 387 | for idx in range(self.num_encs + 1): |
| 388 | self.att[idx].reset() # reset pre-computation of h in atts and han |
| 389 | |
| 390 | # search params |
| 391 | beam = recog_args.beam_size |
| 392 | penalty = recog_args.penalty |
| 393 | ctc_weight = getattr(recog_args, "ctc_weight", False) # for NMT |
| 394 | |
| 395 | if lpz[0] is not None and self.num_encs > 1: |
| 396 | # weights-ctc, |
| 397 | # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss |
| 398 | weights_ctc_dec = recog_args.weights_ctc_dec / np.sum( |
| 399 | recog_args.weights_ctc_dec |
nothing calls this directly
no test coverage detected