(cls, batch_size: int, config: DiaConfig, device: torch.device)
| 187 | |
@classmethod
def new(cls, batch_size: int, config: DiaConfig, device: torch.device) -> "DecoderOutput":
    """Build an instance with a pre-allocated token buffer and no prefill steps.

    Args:
        batch_size: Size of the first dimension of the token buffer.
        config: Configuration whose ``decoder_config`` supplies
            ``max_position_embeddings`` (second buffer dimension) and
            ``num_channels`` (third buffer dimension).
        device: Device on which the token buffer is allocated.

    Returns:
        An instance of ``cls`` whose ``generated_tokens`` is a ``torch.int``
        tensor of shape ``(batch_size, max_position_embeddings, num_channels)``
        filled with ``-1``, and whose ``prefill_steps`` is a new empty list.
    """
    decoder_cfg = config.decoder_config
    buffer_shape = (
        batch_size,
        decoder_cfg.max_position_embeddings,
        decoder_cfg.num_channels,
    )
    # NOTE(review): -1 presumably marks positions that have not been written
    # yet -- confirm against the code that reads ``generated_tokens``.
    token_buffer = torch.full(buffer_shape, fill_value=-1, dtype=torch.int, device=device)
    return cls(generated_tokens=token_buffer, prefill_steps=[])
| 200 | |
| 201 | def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor: |
| 202 | if step_to is None: |