Load the tensor cache from a directory or a JSON file. Parameters ---------- cachepath: str Path to the cache directory (which must contain "tensor-cache.json") or directly to the JSON file. device: tvm.runtime.Device The device on which the loaded tensors are allocated.
(cachepath: str, device: tvm.runtime.Device)
| 331 | |
| 332 | |
| 333 | def load_tensor_cache(cachepath: str, device: tvm.runtime.Device): |
| 334 | """Load the tensor cache from the directory or json. |
| 335 | |
| 336 | |
| 337 | Parameters |
| 338 | ---------- |
| 339 | cachepath: str |
| 340 | Path to the location or json file. |
| 341 | |
| 342 | device: tvm.runtime.Device |
| 343 | The device we would like to load the data from. |
| 344 | """ |
| 345 | if not cachepath.endswith(".json"): |
| 346 | cachepath = os.path.join(cachepath, "tensor-cache.json") |
| 347 | |
| 348 | cachedir = os.path.dirname(cachepath) |
| 349 | json_info = json.loads(open(cachepath).read()) |
| 350 | result_dict = {} |
| 351 | |
| 352 | for shard_rec in json_info["records"]: |
| 353 | data_path = shard_rec["dataPath"] |
| 354 | full_data_path = os.path.join(cachedir, data_path) |
| 355 | raw_data = open(full_data_path, "rb").read() |
| 356 | assert shard_rec["format"] == "raw-shard" |
| 357 | assert shard_rec["nbytes"] == len(raw_data) |
| 358 | |
| 359 | for rec in shard_rec["records"]: |
| 360 | name = rec["name"] |
| 361 | shape = rec["shape"] |
| 362 | dtype = rec["dtype"] |
| 363 | encode_format = rec["format"] |
| 364 | offset = rec["byteOffset"] |
| 365 | nbytes = rec["nbytes"] |
| 366 | |
| 367 | arr = tvm.runtime.empty(shape, dtype, device=device) |
| 368 | assert offset + nbytes <= len(raw_data) |
| 369 | buffer_source = raw_data[offset : offset + nbytes] |
| 370 | if dtype == "float8_e4m3fn": |
| 371 | if ml_dtypes is not None: |
| 372 | dtype = ml_dtypes.float8_e4m3fn |
| 373 | else: |
| 374 | raise RuntimeError( |
| 375 | "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy." |
| 376 | ) |
| 377 | if dtype == "float8_e5m2": |
| 378 | if ml_dtypes is not None: |
| 379 | dtype = ml_dtypes.float8_e5m2 |
| 380 | else: |
| 381 | raise RuntimeError( |
| 382 | "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy." |
| 383 | ) |
| 384 | if encode_format == "f32-to-bf16" and dtype == "float32": |
| 385 | data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape) |
| 386 | arr.copyfrom(_convert_bf16_to_f32(data)) |
| 387 | elif dtype == "bfloat16": |
| 388 | data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape) |
| 389 | arr.copyfrom(data) |
| 390 | else: |