Repository: github.com/jundot/omlx

Function _write_safetensors

tests/test_oq.py:1242–1273

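Write a minimal safetensors file from a {name: value} dict. Values can be np.ndarray (dtype inferred: float16 is written as F16, float32 as F32) or (raw_bytes, shape, sf_dtype) tuples for other dtypes, such as F8_E4M3 and F8_E8M0 (which numpy lacks) and I8 (which the helper's dtype map does not cover).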
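As a sketch of the tuple form, the helpers below build (raw_bytes, shape, sf_dtype) values. The names e8m0_tuple and i8_tuple are hypothetical and not part of omlx. The E8M0 encoding assumes the OCP Microscaling convention: an unsigned 8-bit exponent with bias 127, so byte b decodes to 2**(b - 127), and 0xFF is reserved for NaN.

```python
import numpy as np


def e8m0_tuple(exponents, shape):
    """Hypothetical helper: pack unbiased power-of-two exponents as F8_E8M0.

    Assumes bias 127 with 0xFF reserved for NaN, so valid unbiased
    exponents are -127..127 and byte b encodes 2**(b - 127).
    """
    e = np.asarray(exponents, dtype=np.int16).reshape(-1)
    if e.size != int(np.prod(shape)):
        raise ValueError("number of exponents does not match shape")
    if e.size and (e.min() < -127 or e.max() > 127):
        raise ValueError("exponent out of E8M0 range")
    raw = (e + 127).astype(np.uint8).tobytes()  # add the bias, one byte each
    return (raw, list(shape), "F8_E8M0")


def i8_tuple(arr):
    """Hypothetical helper: wrap int8 data as an I8 tuple value."""
    a = np.ascontiguousarray(arr, dtype=np.int8)  # C-order, two's complement
    return (a.tobytes(), list(a.shape), "I8")


tensors = {
    "w.scales": e8m0_tuple([0, 1, -2, 3], (2, 2)),  # 1.0, 2.0, 0.25, 8.0
    "w.q": i8_tuple([[1, -1], [127, -128]]),
}
```

The resulting dict can be passed to _write_safetensors(path, tensors) as-is. Tuple values skip dtype inference entirely, so the caller is responsible for making raw_bytes consistent with sf_dtype and shape.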

(path, tensors)


```python
import json
import struct

import numpy as np


def _write_safetensors(path, tensors):
    """Write a minimal safetensors file from a {name: value} dict.

    Values can be np.ndarray (float16 -> F16, float32 -> F32) or
    (raw_bytes, shape, sf_dtype) tuples for other dtypes, such as
    F8_E4M3 and F8_E8M0 (which numpy lacks) and I8."""
    header = {}
    data_parts = []
    offset = 0
    # Key on np.dtype objects: np.dtype("float32") == np.float32 is True,
    # but the two hash differently, so a dict keyed on the scalar type
    # never matches arr.dtype.
    dtype_map = {np.dtype(np.float16): "F16", np.dtype(np.float32): "F32"}
    for name, val in tensors.items():
        if isinstance(val, tuple):
            # Pre-encoded tensor: the bytes are written verbatim.
            raw, shape, sf_dtype = val
        else:
            sf_dtype = dtype_map.get(val.dtype)
            if sf_dtype is None:
                raise ValueError(
                    f"{name}: unsupported dtype {val.dtype}; "
                    "pass a (raw_bytes, shape, sf_dtype) tuple instead"
                )
            raw = val.tobytes()  # C-order bytes
            shape = list(val.shape)
        # data_offsets are [begin, end) relative to the start of the data section.
        header[name] = {
            "dtype": sf_dtype,
            "shape": list(shape),
            "data_offsets": [offset, offset + len(raw)],
        }
        data_parts.append(raw)
        offset += len(raw)
    hdr_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with open(path, "wb") as f:
        # Layout: u64 little-endian header length, JSON header, tensor data.
        f.write(struct.pack("<Q", len(hdr_bytes)))
        f.write(hdr_bytes)
        for part in data_parts:
            f.write(part)
```
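To check what the helper wrote, a reader only has to reverse its three steps: read the 8-byte little-endian header length, decode the JSON header, then slice the tensor bytes using each entry's data_offsets. The sketch below is a minimal parser under that assumption; read_safetensors is a hypothetical name, not an omlx or safetensors API. It skips the optional "__metadata__" key, which the safetensors format reserves for a string-to-string map rather than a tensor.

```python
import json
import os
import struct
import tempfile


def read_safetensors(path):
    """Hypothetical parser: return {name: (dtype, shape, raw_bytes)}."""
    with open(path, "rb") as f:
        (hdr_len,) = struct.unpack("<Q", f.read(8))  # u64 little-endian
        header = json.loads(f.read(hdr_len).decode("utf-8"))
        data = f.read()  # everything after the header is tensor data
    out = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional metadata map, not a tensor
            continue
        begin, end = info["data_offsets"]  # relative to the data section
        out[name] = (info["dtype"], info["shape"], data[begin:end])
    return out


# Usage: build a one-tensor file by hand with the same layout, then parse it.
hdr = json.dumps(
    {"a": {"dtype": "I8", "shape": [3], "data_offsets": [0, 3]}},
    separators=(",", ":"),
).encode("utf-8")
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as tmp:
    tmp.write(struct.pack("<Q", len(hdr)) + hdr + bytes([1, 2, 255]))
parsed = read_safetensors(tmp.name)
os.remove(tmp.name)
```

In a test, the same parser can confirm that each tensor's byte count equals its element count times the dtype width, which is the mismatch that a wrong dtype label would produce.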

Calls (4): append (method, 0.80), items (method, 0.45), get (method, 0.45), encode (method, 0.45)

Tested by

no test coverage detected