| 14 | |
@dataclass
class DACFile:
    """Compressed codes together with the metadata stored beside them.

    ``save`` writes the instance to a ``.dac`` file (a dict serialized through
    ``np.save``) and ``load`` rebuilds an instance from such a file.
    """

    codes: torch.Tensor

    # Metadata
    chunk_length: int
    original_length: int
    # Annotated as float, but ``save`` also accepts a torch tensor or a NumPy
    # value here; ``load`` hands back the float32 NumPy value that was stored.
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` (tensor, ndarray or Python scalar) as a NumPy array."""
        if isinstance(value, torch.Tensor):
            # Tensor.numpy() rejects tensors that require grad or live on a
            # non-CPU device, so detach and move to host memory first.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def save(self, path):
        """Write the codes and metadata to ``path`` and return the written path.

        Args:
            path: Destination; its suffix is replaced with ``.dac``.

        Returns:
            The ``Path`` that was actually written.
        """
        artifacts = {
            # Codes are stored as unsigned 16-bit integers.
            "codes": self._to_numpy(self.codes).astype(np.uint16),
            "metadata": {
                # Going through _to_numpy lets this work for a plain float and
                # for the NumPy value produced by load(), not only for tensors.
                "input_db": self._to_numpy(self.input_db).astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                # The file is stamped with the newest supported version, not
                # with self.dac_version.
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        """Read a file written by ``save`` and return a new instance.

        Args:
            path: Location of the ``.dac`` file.

        Returns:
            An instance whose ``codes`` is an integer tensor and whose other
            fields are taken from the stored metadata dict.

        Raises:
            RuntimeError: If the stored ``dac_version`` is missing or is not
                listed in ``SUPPORTED_VERSIONS``.
        """
        # SECURITY: allow_pickle=True unpickles the file contents, which can
        # execute arbitrary code. Only load .dac files from trusted sources.
        # The trailing [()] unwraps the 0-d object array that holds the dict.
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        metadata = artifacts["metadata"]
        if metadata.get("dac_version") not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **metadata)
| 55 | |
| 56 | |
| 57 | class CodecMixin: |