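Write a minimal safetensors file from a {name: value} dict. Values can be np.ndarrays (safetensors dtype inferred from the array) or (raw_bytes, shape, sf_dtype) tuples, for dtypes numpy cannot represent natively (e.g. F8_E4M3, F8_E8M0) or when an explicit dtype tag such as I8 should be written verbatim.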
(path, tensors)
| 1240 | |
| 1241 | |
| 1242 | def _write_safetensors(path, tensors): |
| 1243 | """Write a minimal safetensors file from {name: np.ndarray} dict. |
| 1244 | |
| 1245 | Values can be np.ndarray (auto-dtype) or (raw_bytes, shape, sf_dtype) tuples |
| 1246 | for dtypes numpy doesn't support (F8_E4M3, F8_E8M0, I8).""" |
| 1247 | import json |
| 1248 | import struct |
| 1249 | |
| 1250 | header = {} |
| 1251 | data_parts = [] |
| 1252 | offset = 0 |
| 1253 | dtype_map = {np.float16: "F16", np.float32: "F32", np.dtype("<f2"): "F16"} |
| 1254 | for name, val in tensors.items(): |
| 1255 | if isinstance(val, tuple): |
| 1256 | raw, shape, sf_dtype = val |
| 1257 | else: |
| 1258 | raw = val.tobytes() |
| 1259 | shape = list(val.shape) |
| 1260 | sf_dtype = dtype_map.get(val.dtype, "F16") |
| 1261 | header[name] = { |
| 1262 | "dtype": sf_dtype, |
| 1263 | "shape": list(shape), |
| 1264 | "data_offsets": [offset, offset + len(raw)], |
| 1265 | } |
| 1266 | data_parts.append(raw) |
| 1267 | offset += len(raw) |
| 1268 | hdr_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") |
| 1269 | with open(path, "wb") as f: |
| 1270 | f.write(struct.pack("<Q", len(hdr_bytes))) |
| 1271 | f.write(hdr_bytes) |
| 1272 | for part in data_parts: |
| 1273 | f.write(part) |
| 1274 | |
| 1275 | |
| 1276 | @pytest.mark.skipif(not HAS_MLX, reason="MLX not available") |
no test coverage detected