(self, x: torch.Tensor)
| 187 | return out |
| 188 | |
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Pass ``x`` through ``self.encode`` and then ``self.decode``.

    The decoded tensor is sliced on its third axis so that it holds at most
    ``x.shape[-1]`` entries there; the first two axes are returned whole.

    Args:
        x: Input tensor. Only its last-axis size is read directly here;
            everything else is delegated to ``self.encode``.

    Returns:
        The result of ``self.decode`` cropped along the third axis.
    """
    reconstruction = self.decode(self.encode(x))
    # NOTE(review): presumably drops extra trailing samples produced by the
    # encode/decode round trip -- confirm against encode()/decode().
    target_length = x.shape[-1]
    return reconstruction[:, :, :target_length]
| 192 | |
| 193 | def set_target_bandwidth(self, bandwidth: float): |
| 194 | if bandwidth not in self.target_bandwidths: |