MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / serialize_managed_weights

Function serialize_managed_weights

tensorrt_llm/builder.py:893–935  ·  view source on GitHub ↗
(managed_weights: dict[str, np.ndarray],
                              path: str | Path,
                              metadata=None)

Source from the content-addressed store, hash-verified

891
892
def _safetensors_dtype_name(dtype) -> str:
    """Return the header dtype tag ("F32", "BF16", ...) for a NumPy dtype.

    Raises:
        RuntimeError: if ``dtype`` equals none of the supported dtypes.
    """
    # Compared one by one with ``==`` in a fixed order rather than via a dict
    # lookup: np_bfloat16 and np_float8 are defined elsewhere in this module
    # and their hashing behaviour is not visible from here.
    supported = (
        (np.float32, "F32"),
        (np.float16, "F16"),
        (np_bfloat16, "BF16"),
        (np_float8, "F8_E4M3"),
        (np.int64, "I64"),
        (np.int32, "I32"),
        (np.int8, "I8"),
    )
    for candidate, tag in supported:
        if dtype == candidate:
            return tag
    raise RuntimeError(f"Unsupported dtype: {dtype}")


def serialize_managed_weights(managed_weights: dict[str, np.ndarray],
                              path: str | Path,
                              metadata=None) -> None:
    """Write ``managed_weights`` to ``path`` as a single binary file.

    File layout, in order:
      1. 8 bytes: length of the JSON header, little-endian.
      2. The JSON header. It maps every weight name to its ``dtype`` tag,
         ``shape`` and ``data_offsets`` ``[begin, end]``; offsets are counted
         from the first byte after the header.
      3. The raw bytes of every array, in dict iteration order.

    NOTE(review): the layout and dtype tags look like the safetensors format;
    confirm against the matching reader.

    Args:
        managed_weights: Weight name -> array. Iteration order determines the
            order of the data section.
        path: Destination file; opened with mode ``"wb"``.
        metadata: Optional value stored under the ``"__metadata__"`` header
            key. Must be JSON-serializable.

    Raises:
        RuntimeError: if an array has an unsupported dtype. This is raised
            before the destination file is opened.
    """
    header = {}
    if metadata is not None:
        header["__metadata__"] = metadata

    # First pass: assign each weight a byte range in the data section.
    begin = 0
    for name, value in managed_weights.items():
        end = begin + value.nbytes
        header[name] = {
            "dtype": _safetensors_dtype_name(value.dtype),
            "shape": value.shape,
            "data_offsets": [begin, end],
        }
        begin = end

    # Encode once and measure the encoded bytes, so the length prefix always
    # describes exactly what is written after it.
    header_bytes = json.dumps(header).encode()
    with open(path, "wb") as f:
        logger.info(
            f"Serializing {len(managed_weights)} managed weights to {path}...")
        f.write(len(header_bytes).to_bytes(8, byteorder="little"))
        f.write(header_bytes)
        for name, value in managed_weights.items():
            logger.debug(f"Serializing managed weight: {name}")
            # file.write() rejects buffers that are not C-contiguous (e.g. a
            # transposed view), which would abort after the header is already
            # on disk. Copy only in that case; contiguous arrays are written
            # straight from their own buffer.
            if not value.flags.c_contiguous:
                value = np.ascontiguousarray(value)
            f.write(value.data)
936
937
938def deserialize_managed_weights(path: str | Path) -> dict[str, np.ndarray]:

Callers 1

saveMethod · 0.85

Calls 3

infoMethod · 0.45
encodeMethod · 0.45
debugMethod · 0.45

Tested by

no test coverage detected