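Forward pass for training. Args: hs_pad: encoder output; a single tensor in single-encoder mode (wrapped into a one-element list internally), otherwise a sequence with one entry per encoder. hlens: encoder output lengths, one set per encoder (converted internally to a list of lists of ints). ys_in_pad: padded decoder input sequence; it is passed through the embedding layer, and its second dimension determines the number of decoding steps. ys_in_lens: Lengths of ys_in. strm_idx: stream index used to select the attention module in speaker parallel attention (SPA); it is clamped to the index of the last attention module, so in other cases the effective index is 0 when only one attention module exists.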
(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0)
| 218 | return z_list, c_list |
| 219 | |
| 220 | def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0): |
| 221 | # to support mutiple encoder asr mode, in single encoder mode, |
| 222 | # convert torch.Tensor to List of torch.Tensor |
| 223 | """Forward pass for training. |
| 224 | |
| 225 | Args: |
| 226 | hs_pad: TODO. |
| 227 | hlens: TODO. |
| 228 | ys_in_pad: TODO. |
| 229 | ys_in_lens: Lengths of ys_in. |
| 230 | strm_idx: TODO. |
| 231 | """ |
| 232 | if self.num_encs == 1: |
| 233 | hs_pad = [hs_pad] |
| 234 | hlens = [hlens] |
| 235 | |
| 236 | # attention index for the attention module |
| 237 | # in SPA (speaker parallel attention), |
| 238 | # att_idx is used to select attention module. In other cases, it is 0. |
| 239 | att_idx = min(strm_idx, len(self.att_list) - 1) |
| 240 | |
| 241 | # hlens should be list of list of integer |
| 242 | hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)] |
| 243 | |
| 244 | # get dim, length info |
| 245 | olength = ys_in_pad.size(1) |
| 246 | |
| 247 | # initialization |
| 248 | c_list = [self.zero_state(hs_pad[0])] |
| 249 | z_list = [self.zero_state(hs_pad[0])] |
| 250 | for _ in range(1, self.dlayers): |
| 251 | c_list.append(self.zero_state(hs_pad[0])) |
| 252 | z_list.append(self.zero_state(hs_pad[0])) |
| 253 | z_all = [] |
| 254 | if self.num_encs == 1: |
| 255 | att_w = None |
| 256 | self.att_list[att_idx].reset() # reset pre-computation of h |
| 257 | else: |
| 258 | att_w_list = [None] * (self.num_encs + 1) # atts + han |
| 259 | att_c_list = [None] * self.num_encs # atts |
| 260 | for idx in range(self.num_encs + 1): |
| 261 | # reset pre-computation of h in atts and han |
| 262 | self.att_list[idx].reset() |
| 263 | |
| 264 | # pre-computation of embedding |
| 265 | eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim |
| 266 | |
| 267 | # loop for an output sequence |
| 268 | for i in range(olength): |
| 269 | if self.num_encs == 1: |
| 270 | att_c, att_w = self.att_list[att_idx]( |
| 271 | hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w |
| 272 | ) |
| 273 | else: |
| 274 | for idx in range(self.num_encs): |
| 275 | att_c_list[idx], att_w_list[idx] = self.att_list[idx]( |
| 276 | hs_pad[idx], |
| 277 | hlens[idx], |
No direct callers of this method were found.
No test coverage was detected.