(self, token_data)
| 275 | |
@T.no_grad()
def untokenize(self, token_data):
    """Decode latent frames to audio, optionally prepending a rolling latent cache.

    If ``self.audio_latent_cache`` is set, it is concatenated in front of
    ``token_data`` along dim 1 so the decoder sees left context. The cache is
    then refreshed with the trailing ``6*8`` entries along that axis. If the
    cache is not set but ``self.use_audio_cache`` is true, the cache is
    initialised from ``token_data``.

    If the last dimension of the (possibly cache-extended) input equals
    ``2 * self.c.latent_size``, the two halves of that axis are decoded
    separately and the results are concatenated along dim 1 of the decoder
    output.

    Args:
        token_data: Latent frames. Dim 1 is treated as the frame axis (used for
            the cache concat and the output trim), and the last dim as the
            latent axis. Presumably shaped (batch, frames, latent_dim); TODO
            confirm with callers.

    Returns:
        The decoder output restricted to its last ``n_new * 2000`` entries on
        the last axis. ``n_new`` is the number of frames passed in by the
        caller and excludes any cached context frames. An empty input yields
        an empty slice.
    """
    cache_frames = 6 * 8  # trailing entries (along dim 1) kept in the cache
    samples_per_frame = 2000  # output length per input frame assumed by the trim

    # Count the caller's frames BEFORE prepending the cache. Only these
    # correspond to new audio; the cached frames serve only as decoder context.
    n_new_frames = token_data.shape[1]

    if exists(self.audio_latent_cache):
        token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
        self.audio_latent_cache = token_data[:, -cache_frames:]
    elif self.use_audio_cache:
        self.audio_latent_cache = token_data[:, -cache_frames:]

    latent_size = self.c.latent_size
    if token_data.shape[-1] == 2 * latent_size:
        # The condition above inspects the LAST axis, so split on that same
        # axis; slicing dim 1 would cut the frame axis instead.
        dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[..., :latent_size])
        dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[..., latent_size:])
        decoded = T.cat([dec_ch1, dec_ch2], dim=1)
    else:
        decoded = self.audio_tokenizer.data_from_latent(token_data)

    # Drop the audio decoded from cached context frames. Use an explicit,
    # clamped start index so that n_new_frames == 0 gives an empty slice;
    # a `[..., -0:]` slice would return the whole tensor.
    start = max(decoded.shape[-1] - n_new_frames * samples_per_frame, 0)
    return decoded[..., start:]
| 290 | |
| 291 | def init_cache(self, bsize, device, dtype, length:int=None): |
| 292 | cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim] |
no test coverage detected