(self)
| 373 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) |
| 374 | |
| 375 | def deinit_cache(self): |
| 376 | self.cache = [None] * len(self.cache) |
| 377 | |
| 378 | def forward(self, x: Tensor) -> Tensor: |
| 379 | for l, layer in enumerate(self.layers): |