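Args: input (torch.Tensor): Input tensor of shape (B, T, D). *args: streaming caches, one per layer in self.fsmn, passed positionally; args[i] is used as in_cache for the i-th layer. Each cache is a torch.Tensor of shape (B, T1, D), where T1 equals self.lorder. Returns: the softmax output and out_caches, the list of updated per-layer caches to pass back as *args on the next call.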
forward(self, input: torch.Tensor, *args)
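The caching contract described above can be illustrated with a pure-Python toy rather than the real model. `memory_layer` is a hypothetical stand-in for one FSMN block that works on scalar frames (D = 1), combines the current frame with `lorder` past frames, and returns the last `lorder` frames as its cache, in line with the docstring's statement that T1 equals `self.lorder`. The weights, the residual add, and the zero-filled initial cache are assumptions chosen for illustration. The point of the check is that feeding the signal in chunks while threading the cache reproduces the full-sequence output.

```python
from typing import List, Tuple


def memory_layer(
    x: List[float], in_cache: List[float], weights: List[float]
) -> Tuple[List[float], List[float]]:
    """Hypothetical stand-in for one FSMN block on scalar frames.

    Each output frame is the input frame plus a weighted sum over the
    lorder previous frames and the current frame, where lorder = len(in_cache).
    """
    lorder = len(in_cache)
    if lorder < 1 or len(weights) != lorder + 1:
        raise ValueError("need lorder >= 1 and len(weights) == lorder + 1")
    padded = in_cache + x  # cached past context followed by the new frames
    y = []
    for t in range(len(x)):
        window = padded[t : t + lorder + 1]  # lorder past frames, then frame t
        y.append(x[t] + sum(w * v for w, v in zip(weights, window)))
    out_cache = padded[-lorder:]  # the last lorder frames seed the next call
    return y, out_cache


weights = [0.25, 0.5, 1.0]  # lorder = 2
signal = [1.0, 2.0, 3.0, 4.0, 5.0]

# Whole sequence in one call, starting from an assumed all-zero cache.
full, _ = memory_layer(signal, [0.0, 0.0], weights)
print(full)  # [2.0, 4.5, 7.25, 10.0, 12.75]

# Same signal, two frames at a time, threading the cache between calls.
cache = [0.0, 0.0]
streamed: List[float] = []
for start in range(0, len(signal), 2):
    y, cache = memory_layer(signal[start : start + 2], cache, weights)
    streamed.extend(y)
print(streamed == full)  # True
```

Without the cache, every chunk after the first would see zeros instead of the preceding frames, and the chunked output would drift from the full-sequence result.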
```python
import torch  # needed for the type annotation below


def forward(
    self,
    input: torch.Tensor,
    *args,
):
    """
    Args:
        input (torch.Tensor): Input tensor of shape (B, T, D).
        *args: Streaming caches, one per layer in self.fsmn, passed positionally;
            args[i] is used as in_cache for the i-th layer. Each cache is a
            torch.Tensor of shape (B, T1, D), where T1 equals self.lorder.

    Returns:
        A tuple of the softmax output and out_caches, the list of updated
        per-layer caches to pass back as *args on the next call.
    """
    # Input projection followed by the ReLU activation.
    x = self.in_linear1(input)
    x = self.in_linear2(x)
    x = self.relu(x)

    # Run each FSMN layer with its own cache and collect the updated caches.
    out_caches = []
    for i, d in enumerate(self.fsmn):
        in_cache = args[i]  # raises IndexError if fewer caches than layers are passed
        x, out_cache = d(x, in_cache)
        out_caches.append(out_cache)

    # Output projection and softmax.
    x = self.out_linear1(x)
    x = self.out_linear2(x)
    x = self.softmax(x)

    return x, out_caches
```
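The per-layer loop in `forward()` implies a calling convention for streaming: pass one cache per FSMN layer positionally, then feed the returned `out_caches` back as `*args` for the next chunk. The sketch below mimics that loop with hypothetical pure-Python stand-ins (`make_delay_sum_layer`, `toy_forward`, `stream`) in place of the real layers; scalar lists replace the (B, T1, D) tensors, the linear, ReLU, and softmax stages are omitted, and starting from zero-filled caches is an assumption rather than something the source specifies.

```python
from typing import Callable, List, Sequence, Tuple

Cache = List[float]
Layer = Callable[[List[float], Cache], Tuple[List[float], Cache]]


def make_delay_sum_layer(lorder: int) -> Layer:
    """Hypothetical toy layer: y[t] = x[t] + sum of the previous lorder inputs."""
    if lorder < 1:
        raise ValueError("lorder must be >= 1")

    def layer(x: List[float], in_cache: Cache) -> Tuple[List[float], Cache]:
        padded = in_cache + x  # cached context, then the new frames
        y = [padded[t + lorder] + sum(padded[t : t + lorder]) for t in range(len(x))]
        return y, padded[-lorder:]  # the last lorder inputs become the new cache

    return layer


def toy_forward(
    layers: Sequence[Layer], x: List[float], *args: Cache
) -> Tuple[List[float], List[Cache]]:
    """Mirrors the loop in forward(): args[i] is the cache for layers[i]."""
    out_caches = []
    for i, layer in enumerate(layers):
        x, out_cache = layer(x, args[i])
        out_caches.append(out_cache)
    return x, out_caches


def stream(
    layers: Sequence[Layer], lorders: Sequence[int], signal: List[float], chunk: int
) -> List[float]:
    """Process signal chunk by chunk, feeding out_caches back as *args."""
    caches = [[0.0] * k for k in lorders]  # assumed zero-filled caches for the first chunk
    outputs: List[float] = []
    for start in range(0, len(signal), chunk):
        y, caches = toy_forward(layers, signal[start : start + chunk], *caches)
        outputs.extend(y)
    return outputs


lorders = [2, 3]
layers = [make_delay_sum_layer(k) for k in lorders]
signal = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

full, _ = toy_forward(layers, signal, *[[0.0] * k for k in lorders])
print(full)  # [1.0, 4.0, 10.0, 19.0, 30.0, 42.0]
print(all(stream(layers, lorders, signal, c) == full for c in (1, 2, 4)))  # True

try:
    toy_forward(layers, signal, [0.0, 0.0])  # only one cache for two layers
except IndexError as err:
    print("IndexError:", err)
```

Because `forward()` reads `args[i]` directly, passing fewer caches than there are layers raises `IndexError`, while extra positional arguments are silently ignored; the final call in the sketch shows the first case.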