| 891 | |
| 892 | |
| 893 | def serialize_managed_weights(managed_weights: dict[str, np.ndarray], |
| 894 | path: str | Path, |
| 895 | metadata=None) -> None: |
| 896 | header = {} |
| 897 | if metadata is not None: |
| 898 | header["__metadata__"] = metadata |
| 899 | begin = 0 |
| 900 | for name, value in managed_weights.items(): |
| 901 | size = value.size * value.itemsize |
| 902 | if value.dtype == np.float32: |
| 903 | dtype = "F32" |
| 904 | elif value.dtype == np.float16: |
| 905 | dtype = "F16" |
| 906 | elif value.dtype == np_bfloat16: |
| 907 | dtype = "BF16" |
| 908 | elif value.dtype == np_float8: |
| 909 | dtype = "F8_E4M3" |
| 910 | elif value.dtype == np.int64: |
| 911 | dtype = "I64" |
| 912 | elif value.dtype == np.int32: |
| 913 | dtype = "I32" |
| 914 | elif value.dtype == np.int8: |
| 915 | dtype = "I8" |
| 916 | else: |
| 917 | raise RuntimeError(f"Unsupported dtype: {value.dtype}") |
| 918 | header[name] = { |
| 919 | "dtype": dtype, |
| 920 | "shape": value.shape, |
| 921 | "data_offsets": [begin, begin + size], |
| 922 | } |
| 923 | begin += size |
| 924 | |
| 925 | header_json = json.dumps(header) |
| 926 | header_json_len = len(header_json) |
| 927 | with open(path, "wb") as f: |
| 928 | logger.info( |
| 929 | f"Serializing {len(managed_weights)} managed weights to {path}...") |
| 930 | f.write(header_json_len.to_bytes(8, byteorder="little")) |
| 931 | f.write(header_json.encode()) |
| 932 | for name, value in managed_weights.items(): |
| 933 | logger.debug(f"Serializing managed weight: {name}") |
| 934 | buf = value.data |
| 935 | f.write(buf) |
| 936 | |
| 937 | |
| 938 | def deserialize_managed_weights(path: str | Path) -> dict[str, np.ndarray]: |