| 145 | return encoded_frames |
| 146 | |
def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
    """Encode a single frame of audio into quantizer codes.

    The frame is optionally loudness-normalized, passed through
    ``self.encoder`` and then through ``self.quantizer.encode``.

    Args:
        x: Input tensor. The last dimension is treated as the sample axis;
            when ``self.normalize`` is set, dimensions 1 and 2 are reduced,
            so ``x`` is presumably ``[B, C, T]`` -- confirm against callers.

    Returns:
        A ``(codes, scale)`` pair. ``codes`` is ``[B, K, T]`` with ``T``
        frames and ``K`` codebooks. ``scale`` is the per-item normalization
        factor reshaped to ``(-1, 1)``, or ``None`` when ``self.normalize``
        is falsy.

    Raises:
        AssertionError: if ``self.segment`` is set and the frame lasts longer
            than ``self.segment`` (with a ``1e-5`` tolerance).
    """
    num_samples = x.shape[-1]
    seconds = num_samples / self.sample_rate
    if self.segment is not None:
        # Small tolerance so rounding in the duration does not trip the check.
        assert seconds <= 1e-5 + self.segment

    scale = None
    if self.normalize:
        # Root-mean-square level of the dim-1 average, one value per item.
        downmix = x.mean(dim=1, keepdim=True)
        rms = downmix.pow(2).mean(dim=2, keepdim=True).sqrt()
        gain = 1e-8 + rms  # epsilon guards against division by zero
        x = x / gain
        scale = gain.view(-1, 1)

    latent = self.encoder(x)
    quantized = self.quantizer.encode(latent, self.frame_rate, self.bandwidth)
    # Swap the first two axes so that codes is [B, K, T],
    # with T frames, K nb of codebooks.
    codes = quantized.transpose(0, 1)
    return codes, scale
| 166 | |
| 167 | def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor: |
| 168 | """Decode the given frames into a waveform. |