(self, batch)
| 48 | return ret |
| 49 | |
def _get_audio_embed(self, batch):
    """Compute audio embeddings for a batch of waveforms.

    The batch is resampled from ``self.sampling_rate`` (which must be
    32000) to 48000 Hz. Each waveform yielded by ``self.batch_to_list`` is
    then converted into a feature dict by ``get_audio_features``. The list
    of dicts is passed to ``self.model.get_audio_embedding``.

    Args:
        batch: Waveform tensor sampled at ``self.sampling_rate``. Earlier
            comments described it both as ``[B, samples]`` and as
            ``[bs, 1, t-samples]``. NOTE(review): confirm the expected shape
            against ``batch_to_list``.

    Returns:
        The tensor returned by ``self.model.get_audio_embedding``, detached
        from the autograd graph. It was previously annotated as
        ``[bs, 512]``, which is unverified here.

    Raises:
        AssertionError: If ``self.sampling_rate`` is not 32000.
    """
    # Check explicitly instead of using `assert`. An `assert` is stripped
    # under `python -O`, and an unsupported input rate would then be
    # silently resampled with the wrong ratio. AssertionError and its
    # message are kept so existing callers observe the same exception.
    if self.sampling_rate != 32000:
        raise AssertionError("We only support 32000 sampling rate")

    with torch.no_grad():
        # Bring the audio to 48 kHz before feature extraction.
        batch = torchaudio.functional.resample(
            batch, orig_freq=self.sampling_rate, new_freq=48000
        )
        # One fresh feature dict per waveform. 480000 is presumably the max
        # clip length in samples (10 s at 48 kHz). Confirm against
        # get_audio_features.
        audio_dict_list = [
            get_audio_features(
                {},
                waveform,
                480000,
                data_truncating="fusion",
                data_filling="repeatpad",
                audio_cfg=self.model_cfg["audio_cfg"],
            )
            for waveform in self.batch_to_list(batch)
        ]
        embed = self.model.get_audio_embedding(audio_dict_list)

    return embed.detach()
| 77 | |
| 78 | def _get_text_embed(self, batch): |
| 79 | double_batch = False |
no test coverage detected