(self, input_audios: Tensor, input_audio_lengths: Tensor, audio_span_tokens: List)
| 425 | return x, bos, eos |
| 426 | |
def encode(self, input_audios: Tensor, input_audio_lengths: Tensor, audio_span_tokens: List):
    """Encode a padded audio batch into one output tensor per sample.

    Builds a padding mask from the per-sample lengths and runs the module's
    forward pass (``self(...)``). Each sample's output is then trimmed so
    that, after optionally wrapping it with the ``bos``/``eos`` rows returned
    by the forward pass, it has exactly the requested number of rows.

    Args:
        input_audios: Batched audio input; its first dimension is the batch
            size.
        input_audio_lengths: Per-sample lengths. Only column 0 is read here,
            as the unpadded input length used for the padding mask. The full
            tensor is forwarded to ``self``. Assumes one row per batch sample.
        audio_span_tokens: Expected number of output rows per sample,
            including the ``bos``/``eos`` rows when the forward pass returns
            them.

    Returns:
        A list with one tensor per entry of ``audio_span_tokens``. Tensor
        ``i`` has exactly ``audio_span_tokens[i]`` rows.

    Raises:
        ValueError: If a trimmed (and wrapped) sample does not have the
            expected number of rows.
    """
    weight = self.conv1.weight
    # A single host sync for all lengths instead of one ``.item()`` per sample.
    real_input_audio_lens = input_audio_lengths[:, 0].tolist()
    max_len_in_batch = max(real_input_audio_lens)
    # Mask value 1 marks padded positions (index >= sample length), 0 marks
    # real ones; cast to the conv weights' dtype/device as before.
    positions = torch.arange(max_len_in_batch, device=weight.device)
    lengths = torch.tensor(real_input_audio_lens, device=weight.device)
    padding_mask = (positions.unsqueeze(0) >= lengths.unsqueeze(1)).to(dtype=weight.dtype)

    x, bos, eos = self(input_audios, padding_mask, input_audio_lengths)

    # Rows reserved for the special tokens. The encoder output must leave
    # room for them only when they are actually concatenated; otherwise the
    # full span comes from ``x`` (previously 2 rows were always dropped,
    # which made the ``bos is None`` path fail the length check).
    n_special = 0 if bos is None else len(bos) + len(eos)

    output_audios = []
    for i, audio_span in enumerate(audio_span_tokens):
        audio = x[i][:audio_span - n_special]
        if bos is not None:
            audio = torch.concat([bos, audio, eos])
        if len(audio) != audio_span:
            raise ValueError(
                f"sample {i}: expected {audio_span} audio tokens, got {len(audio)}"
            )
        output_audios.append(audio)
    return output_audios
no outgoing calls
no test coverage detected