(self, x, cache_x=None)
| 31 | self.padding = (0, 0, 0) |
| 32 | |
def forward(self, x, cache_x=None):
    """Apply the stored padding (optionally consuming cached frames), then convolve.

    Args:
        x: input tensor. Frames are concatenated along dim 2, so dim 2 is
            treated as the time axis. Presumably (B, C, T, H, W); TODO confirm.
        cache_x: optional tensor of frames to place in front of ``x`` along
            dim 2. It is only used when ``self._padding[4] > 0``.

    Returns:
        Whatever the parent class's ``forward`` returns for the padded input.
    """
    pad_spec = list(self._padding)

    if cache_x is None or self._padding[4] <= 0:
        return super().forward(F.pad(x, pad_spec))

    # Move the cache onto x's device and prepend it along dim 2.
    x = torch.cat([cache_x.to(x.device), x], dim=2)
    # F.pad index 4 is the leading pad of the third-to-last dim (dim 2 for a
    # 5-D input). Each cached frame replaces one frame of that padding.
    # If the cache is longer than the pad, the value goes negative and F.pad
    # crops instead.
    pad_spec[4] = pad_spec[4] - cache_x.shape[2]

    return super().forward(F.pad(x, pad_spec))
| 42 | |
| 43 | |
| 44 | class RMS_norm(nn.Module): |