Decoder forward :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) [in multi-encoder case, list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None)
| 175 | return z_list, c_list |
| 176 | |
| 177 | def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None): |
| 178 | """Decoder forward |
| 179 | |
| 180 | :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) |
| 181 | [in multi-encoder case, |
| 182 | list of torch.Tensor, |
| 183 | [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ] |
| 184 | :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) |
| 185 | [in multi-encoder case, list of torch.Tensor, |
| 186 | [(B), (B), ..., ] ] |
| 187 | :param torch.Tensor ys_pad: batch of padded character id sequence tensor |
| 188 | (B, Lmax) |
| 189 | :param int strm_idx: stream index indicates the index of decoding stream. |
| 190 | :param torch.Tensor lang_ids: batch of target language id tensor (B, 1) |
| 191 | :return: attention loss value |
| 192 | :rtype: torch.Tensor |
| 193 | :return: accuracy |
| 194 | :rtype: float |
| 195 | """ |
| 196 | # to support multiple encoder asr mode, in single encoder mode, |
| 197 | # convert torch.Tensor to List of torch.Tensor |
| 198 | if self.num_encs == 1: |
| 199 | hs_pad = [hs_pad] |
| 200 | hlens = [hlens] |
| 201 | |
| 202 | # TODO(kan-bayashi): need to find a smarter way to do this |
| 203 | ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys |
| 204 | # attention index for the attention module |
| 205 | # in SPA (speaker parallel attention), |
| 206 | # att_idx is used to select attention module. In other cases, it is 0. |
| 207 | att_idx = min(strm_idx, len(self.att) - 1) |
| 208 | |
| 209 | # hlens should be list of list of integer |
| 210 | hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)] |
| 211 | |
| 212 | self.loss = None |
| 213 | # prepare input and output word sequences with sos/eos IDs |
| 214 | eos = ys[0].new([self.eos]) |
| 215 | sos = ys[0].new([self.sos]) |
| 216 | if self.replace_sos: |
| 217 | ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)] |
| 218 | else: |
| 219 | ys_in = [torch.cat([sos, y], dim=0) for y in ys] |
| 220 | ys_out = [torch.cat([y, eos], dim=0) for y in ys] |
| 221 | |
| 222 | # padding: ys_in is padded with eos, ys_out with ignore_id |
| 223 | # pys: utt x olen |
| 224 | ys_in_pad = pad_list(ys_in, self.eos) |
| 225 | ys_out_pad = pad_list(ys_out, self.ignore_id) |
| 226 | |
| 227 | # get dim, length info |
| 228 | batch = ys_out_pad.size(0) |
| 229 | olength = ys_out_pad.size(1) |
| 230 | for idx in range(self.num_encs): |
| 231 | logging.info( |
| 232 | self.__class__.__name__ |
| 233 | + "Number of Encoder:{}; enc{}: input lengths: {}.".format( |
| 234 | self.num_encs, idx + 1, hlens[idx] |
nothing calls this directly
no test coverage detected