Given a tensor `x`, returns a list of frames containing the discrete encoded codes for `x`, along with rescaling factors for each segment, when `self.normalize` is True. Each frame is a tuple `(codebook, scale)`, with `codebook` of shape `[B, K, T]`, with `K` the number of codebooks.
(self, x: torch.Tensor)
| 120 | return max(1, int((1 - self.overlap) * segment_length)) |
| 121 | |
def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
    """Given a tensor `x`, returns a list of frames containing
    the discrete encoded codes for `x`, along with rescaling factors
    for each segment, when `self.normalize` is True.

    `x` is cut along its last axis into segments of `self.segment_length`
    samples, starting every `self.segment_stride` samples; the final
    segment may be shorter. When `self.segment_length` is None the whole
    input is encoded as a single segment. Each segment is encoded by
    `self._encode_frame`.

    Each frame is a tuple `(codebook, scale)`, with `codebook` of
    shape `[B, K, T]`, with `K` the number of codebooks.

    Args:
        x: rank-3 tensor whose dim 1 (channels) is 1 or 2.

    Returns:
        One encoded frame per segment, in order. An input with zero
        samples along the last axis yields an empty list.
    """
    assert x.dim() == 3
    _, channels, length = x.shape
    assert 0 < channels <= 2
    segment_length = self.segment_length
    if segment_length is None:
        # One segment spanning the whole input. The stride is clamped to 1
        # so an empty input produces no frames (matching the segmented
        # path) instead of `range()` raising on a zero step.
        segment_length = length
        stride = max(1, length)
    else:
        stride = self.segment_stride  # type: ignore
        assert stride is not None

    encoded_frames: tp.List[EncodedFrame] = []
    for offset in range(0, length, stride):
        # Slicing past the end is safe; the last segment is just shorter.
        frame = x[:, :, offset: offset + segment_length]
        encoded_frames.append(self._encode_frame(frame))
    return encoded_frames
| 146 | |
| 147 | def _encode_frame(self, x: torch.Tensor) -> EncodedFrame: |
| 148 | length = x.shape[-1] |
no test coverage detected