Init state. Args: x: TODO.
(self, x)
| 311 | return z_all, ys_in_lens |
| 312 | |
def init_state(self, x):
    """Build the initial decoding state and reset the attention modules.

    Args:
        x: Encoder output. Per the original comments this is a single
            torch.Tensor in single-encoder mode (``num_encs == 1``) and a
            list of torch.Tensor otherwise. Only the first entry is used
            to create the initial states.

    Returns:
        dict: with the keys

        * ``c_prev`` / ``z_prev``: shallow copies of the per-layer lists
          of ``self.zero_state(...)`` results (one entry per layer in
          ``self.dlayers``, and always at least one entry).
        * ``a_prev``: ``None`` for a single encoder, otherwise a list of
          ``num_encs + 1`` ``None`` placeholders (atts + han).
        * ``workspace``: the tuple ``(att_idx, z_list, c_list)`` holding
          the un-copied lists.
    """
    # To support multiple-encoder ASR mode: in single-encoder mode wrap the
    # lone torch.Tensor so that both modes can be handled as a list.
    single_encoder = self.num_encs == 1
    enc_outs = [x] if single_encoder else x

    def _new_zero_state():
        # Every state is created from the first encoder output.
        return self.zero_state(enc_outs[0].unsqueeze(0))

    c_states = []
    z_states = []
    # The first layer is always created, even if dlayers is below one.
    for _ in range(max(1, self.dlayers)):
        c_states.append(_new_zero_state())
        z_states.append(_new_zero_state())

    # TODO(karita): support strm_index for `asr_mix`
    strm_index = 0
    att_idx = min(strm_index, len(self.att_list) - 1)

    if single_encoder:
        prev_att = None
        # reset pre-computation of h
        self.att_list[att_idx].reset()
    else:
        num_att = self.num_encs + 1  # atts + han
        prev_att = [None] * num_att
        # reset pre-computation of h in atts and han
        for i in range(num_att):
            self.att_list[i].reset()

    return {
        "c_prev": c_states[:],
        "z_prev": z_states[:],
        "a_prev": prev_att,
        "workspace": (att_idx, z_states, c_states),
    }
| 346 | |
| 347 | def score(self, yseq, state, x): |
| 348 | # to support mutiple encoder asr mode, in single encoder mode, |
no test coverage detected