(model, audio_file, text, output_file, device='cuda', use_chunk=False)
| 18 | return model |
| 19 | |
def separate_audio(model, audio_file, text, output_file, device='cuda', use_chunk=False):
    """Separate the sound described by ``text`` from ``audio_file`` and save it.

    The input is loaded as mono audio resampled to 32 kHz, a text-query
    embedding is obtained from ``model.query_encoder``, the separation
    network ``model.ss_model`` is run on the mixture, and the result is
    written to ``output_file`` as 16-bit integer samples at 32 kHz.

    Args:
        model: Object exposing ``query_encoder.get_query_embed(...)`` and
            ``ss_model`` (callable, and providing ``chunk_inference`` when
            ``use_chunk`` is true).
        audio_file: Path of the mixture, passed to ``librosa.load``.
        text: Textual query describing the source to extract.
        output_file: Destination passed to ``write``.
        device: Device the mixture tensor is moved to and that is handed to
            the query encoder.
        use_chunk: If true, run ``model.ss_model.chunk_inference``; otherwise
            call ``model.ss_model`` directly and read its ``"waveform"``
            output.
    """
    # One constant for both loading and writing so the two cannot drift apart.
    sample_rate = 32000

    print(f'Separating audio from [{audio_file}] with textual query: [{text}]')
    mixture, _ = librosa.load(audio_file, sr=sample_rate, mono=True)

    with torch.no_grad():
        # The query encoder is given a list of queries; we have exactly one.
        conditions = model.query_encoder.get_query_embed(
            modality='text',
            text=[text],
            device=device
        )

        input_dict = {
            # Add two leading singleton axes in front of the sample axis.
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }

        if use_chunk:
            # NOTE(review): chunk_inference presumably returns a NumPy array
            # (it is post-processed with np.squeeze) -- confirm.
            sep_segment = model.ss_model.chunk_inference(input_dict)
            sep_segment = np.squeeze(sep_segment)
        else:
            sep_segment = model.ss_model(input_dict)["waveform"]
            sep_segment = sep_segment.squeeze(0).squeeze(0).detach().cpu().numpy()

    # Clip before scaling: samples outside [-1, 1] would otherwise overflow
    # the int16 cast and wrap around, producing loud artifacts.
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * 32767).astype(np.int16)

    # NOTE(review): `write` is presumably scipy.io.wavfile.write -- confirm
    # against the file's imports.
    write(output_file, sample_rate, pcm)
    print(f'Separated audio written to [{output_file}]')
| 46 | |
| 47 | if __name__ == '__main__': |
| 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
no test coverage detected