(self, codebook: int, tokens: torch.Tensor)
| 188 | self.decoder.reset_caches() |
| 189 | |
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
    """Embed audio ``tokens`` that belong to codebook index ``codebook``.

    Each token id is shifted by ``codebook * self.config.audio_vocab_size``
    before being looked up through ``self.audio_embeddings``.

    NOTE(review): the shift suggests that every codebook shares a single
    embedding table laid out back to back, with ``audio_vocab_size`` rows per
    codebook. Confirm this against how ``audio_embeddings`` is constructed.
    """
    vocab_offset = codebook * self.config.audio_vocab_size
    shifted_ids = tokens + vocab_offset
    return self.audio_embeddings(shifted_ids)
| 192 | |
| 193 | def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: |
| 194 | text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) |