Encoder forward :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param list prev_states: previous hidden states, one entry per encoder module (defaults to all None) :return: tuple of (batch of hidden state sequences (B, Tmax, eprojs) with padded frames zeroed, output lengths, list of per-module states)
(self, xs_pad, ilens, prev_states=None)
| 333 | self.conv_subsampling_factor = 1 |
| 334 | |
| 335 | def forward(self, xs_pad, ilens, prev_states=None): |
| 336 | """Encoder forward |
| 337 | |
| 338 | :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) |
| 339 | :param torch.Tensor ilens: batch of lengths of input sequences (B) |
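| 340 | :param list prev_states: previous hidden states, one entry per encoder module; None means no previous state for every module |
| 341 | :return: batch of hidden state sequences (B, Tmax, eprojs) with padded frames zeroed, the output lengths, and the list of per-module states |
| 342 | :rtype: tuple |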
| 343 | """ |
| 344 | if prev_states is None: |
| 345 | prev_states = [None] * len(self.enc) |
| 346 | assert len(prev_states) == len(self.enc) |
| 347 | |
| 348 | current_states = [] |
| 349 | for module, prev_state in zip(self.enc, prev_states): |
| 350 | xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) |
| 351 | current_states.append(states) |
| 352 | |
| 353 | # make mask to remove bias value in padded part |
| 354 | mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1)) |
| 355 | |
| 356 | return xs_pad.masked_fill(mask, 0.0), ilens, current_states |
| 357 | |
| 358 | |
| 359 | def encoder_for(args, idim, subsample): |
nothing calls this directly
no test coverage detected