Set up KV caches and register causal-mask buffers.
(self, max_batch_size: int)
| 118 | self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) |
| 119 | |
def setup_caches(self, max_batch_size: int) -> None:
    """Set up the backbone and decoder KV caches and register causal-mask buffers.

    Calls ``setup_caches`` on ``self.backbone`` and ``self.decoder`` using the
    dtype and device of this module's parameters. It then registers two
    buffers built by ``_create_causal_mask``:

    * ``backbone_causal_mask`` -- for length ``self.backbone.max_seq_len``.
    * ``decoder_causal_mask`` -- for length ``self.config.audio_num_codebooks``.

    Args:
        max_batch_size: Largest batch size the caches are set up for.

    Returns:
        None. The masks are exposed as module buffers/attributes, not returned.
        (The previous ``-> torch.Tensor`` annotation was incorrect.)
    """
    # Use one parameter as the reference so the caches and masks follow the
    # model's current dtype and placement (e.g. after ``.to(...)``).
    reference_param = next(self.parameters())
    dtype = reference_param.dtype
    device = reference_param.device

    # A torch.device used as a context manager becomes the default device for
    # tensors created inside the block (PyTorch >= 2.0).
    with device:
        self.backbone.setup_caches(max_batch_size, dtype)
        # Decoder cache length is bounded by the number of audio codebooks.
        self.decoder.setup_caches(
            max_batch_size,
            dtype,
            decoder_max_seq_len=self.config.audio_num_codebooks,
        )

    # NOTE(review): these buffers are persistent, so they appear in
    # state_dict(); confirm that is intended for derived masks.
    self.register_buffer(
        "backbone_causal_mask",
        _create_causal_mask(self.backbone.max_seq_len, device),
    )
    self.register_buffer(
        "decoder_causal_mask",
        _create_causal_mask(self.config.audio_num_codebooks, device),
    )
| 132 | def generate_frame( |
| 133 | self, |
no test coverage detected