(self, prompt_path)
| 59 | self.initialize_state(prompt_path) |
| 60 | |
def initialize_state(self, prompt_path):
    """Load the prompt audio and reset the processor's working state.

    The file at ``prompt_path`` is read with ``torchaudio.load``. It is
    resampled to ``SAMPLE_RATE`` when its native rate differs, expanded to
    two channels when it has exactly one, and cut down to a whole number of
    2000-sample chunks before being moved to ``device`` and stored as
    ``self.loaded_audio``. The model cache is then initialised and the
    batched prompt is handed to ``self.model.next_audio_from_audio``; the
    result is kept as ``self.next_model_audio``. Finally the prompt
    bookkeeping fields are reset and ``initialize_prompt_buffer`` is called.

    Args:
        prompt_path: Location of the prompt audio, passed straight to
            ``torchaudio.load``.
    """
    chunk_samples = 2000      # prompt length is truncated to a multiple of this
    chunks_per_second = 8     # multiplier used to derive chunks_until_live

    waveform, source_rate = torchaudio.load(prompt_path)
    self.replay_seconds = REPLAY_SECONDS

    # Bring the prompt to the sample rate used by the rest of the file.
    if source_rate != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(source_rate, SAMPLE_RATE)(waveform)

    # A single-channel prompt is duplicated along the first dimension.
    channel_count = waveform.shape[0]
    if channel_count == 1:
        waveform = waveform.repeat(2, 1)

    # Drop trailing samples that do not fill a complete chunk.
    usable_samples = (waveform.shape[-1] // chunk_samples) * chunk_samples
    waveform = waveform[..., :usable_samples]

    self.loaded_audio = waveform.to(device)

    # NOTE(review): the with-block is assumed to cover only the two model
    # calls below -- confirm against the original indentation.
    with T.autocast(device_type=device, dtype=T.bfloat16):
        with T.inference_mode():
            self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024)
            batched_prompt = self.loaded_audio.unsqueeze(0)
            self.next_model_audio = self.model.next_audio_from_audio(batched_prompt, temps=TEMPS)

    self.prompt_buffer = None
    self.prompt_position = 0
    self.chunks_until_live = int(self.replay_seconds * chunks_per_second)
    self.initialize_prompt_buffer()
    print_colored("AudioProcessor state initialized", "green")
| 86 | |
| 87 | def initialize_prompt_buffer(self): |
| 88 | self.recorded_audio = self.loaded_audio |
no test coverage detected