Repository: github.com/apache/tvm

Function load_tensor_cache

python/tvm/contrib/tvmjs.py:333–394

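Load the tensor cache from a directory or from its JSON index file.

Parameters:
cachepath (str): path to the cache directory or to the JSON file. A path that does not end in ".json" is treated as a directory containing tensor-cache.json.
device (tvm.runtime.Device): the device on which the loaded arrays are allocated.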

(cachepath: str, device: tvm.runtime.Device)

```python
# Excerpt from python/tvm/contrib/tvmjs.py. The names os, json, np, tvm,
# ml_dtypes and _convert_bf16_to_f32 are imported or defined at module level
# in that file; ml_dtypes may be None when the package is not installed.
def load_tensor_cache(cachepath: str, device: tvm.runtime.Device):
    """Load the tensor cache from the directory or json.


    Parameters
    ----------
    cachepath: str
        Path to the location or json file.

    device: tvm.runtime.Device
        The device we would like to load the data from.
    """
    if not cachepath.endswith(".json"):
        cachepath = os.path.join(cachepath, "tensor-cache.json")

    cachedir = os.path.dirname(cachepath)
    json_info = json.loads(open(cachepath).read())
    result_dict = {}

    for shard_rec in json_info["records"]:
        data_path = shard_rec["dataPath"]
        full_data_path = os.path.join(cachedir, data_path)
        raw_data = open(full_data_path, "rb").read()
        assert shard_rec["format"] == "raw-shard"
        assert shard_rec["nbytes"] == len(raw_data)

        for rec in shard_rec["records"]:
            name = rec["name"]
            shape = rec["shape"]
            dtype = rec["dtype"]
            encode_format = rec["format"]
            offset = rec["byteOffset"]
            nbytes = rec["nbytes"]

            arr = tvm.runtime.empty(shape, dtype, device=device)
            assert offset + nbytes <= len(raw_data)
            buffer_source = raw_data[offset : offset + nbytes]
            if dtype == "float8_e4m3fn":
                if ml_dtypes is not None:
                    dtype = ml_dtypes.float8_e4m3fn
                else:
                    raise RuntimeError(
                        "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy."
                    )
            if dtype == "float8_e5m2":
                if ml_dtypes is not None:
                    dtype = ml_dtypes.float8_e5m2
                else:
                    raise RuntimeError(
                        "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy."
                    )
            if encode_format == "f32-to-bf16" and dtype == "float32":
                data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
                arr.copyfrom(_convert_bf16_to_f32(data))
            elif dtype == "bfloat16":
                data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
                arr.copyfrom(data)
            else:
                # (listing truncated here; source lines 391–394 are not shown)
```

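The shard/record walk in the listing can be exercised without TVM. Below is a minimal, numpy-only sketch that mirrors the traversal and bounds checks of `load_tensor_cache`, plus a tiny writer so it runs end to end. The names `load_cache_numpy` and `write_demo_cache`, the shard file name `shard_0.bin`, and the per-record `"format": "raw"` value are hypothetical; only the JSON keys (`records`, `dataPath`, `format`, `nbytes`, `name`, `shape`, `dtype`, `byteOffset`) and the `tensor-cache.json` file name come from the listing. It returns host numpy arrays and skips the bf16 and float8 branches.

```python
import json
import os
import tempfile

import numpy as np


def load_cache_numpy(cachepath):
    """Hypothetical numpy-only reader mirroring the record walk in load_tensor_cache.

    Returns {name: np.ndarray}. Device allocation and the bf16/float8
    branches of the original are deliberately left out.
    """
    # Same path rule as the original: anything not ending in ".json" is a
    # directory whose index file is tensor-cache.json.
    if not cachepath.endswith(".json"):
        cachepath = os.path.join(cachepath, "tensor-cache.json")
    cachedir = os.path.dirname(cachepath)
    with open(cachepath) as f:
        json_info = json.load(f)

    result = {}
    for shard_rec in json_info["records"]:
        # One shard = one binary file holding several tensors back to back.
        with open(os.path.join(cachedir, shard_rec["dataPath"]), "rb") as f:
            raw_data = f.read()
        assert shard_rec["format"] == "raw-shard"
        assert shard_rec["nbytes"] == len(raw_data)

        for rec in shard_rec["records"]:
            offset, nbytes = rec["byteOffset"], rec["nbytes"]
            # Each tensor must lie entirely inside its shard.
            assert offset + nbytes <= len(raw_data)
            buf = raw_data[offset : offset + nbytes]
            arr = np.frombuffer(buf, dtype=rec["dtype"]).reshape(rec["shape"])
            result[rec["name"]] = arr
    return result


def write_demo_cache(cachedir, tensors):
    """Hypothetical writer: pack tensors into one shard plus a tensor-cache.json."""
    records, blob = [], b""
    for name, arr in tensors.items():
        data = np.ascontiguousarray(arr).tobytes()
        records.append(
            {
                "name": name,
                "shape": list(arr.shape),
                "dtype": str(arr.dtype),
                "format": "raw",  # hypothetical tag; only "f32-to-bf16" appears in the listing
                "byteOffset": len(blob),
                "nbytes": len(data),
            }
        )
        blob += data
    with open(os.path.join(cachedir, "shard_0.bin"), "wb") as f:
        f.write(blob)
    shard = {
        "dataPath": "shard_0.bin",
        "format": "raw-shard",
        "nbytes": len(blob),
        "records": records,
    }
    with open(os.path.join(cachedir, "tensor-cache.json"), "w") as f:
        json.dump({"records": [shard]}, f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        write_demo_cache(d, {"w": np.arange(6, dtype="float32").reshape(2, 3)})
        loaded = load_cache_numpy(d)
        print(loaded["w"])
```

Returning host arrays keeps the sketch runnable anywhere; `load_tensor_cache` instead allocates each tensor with `tvm.runtime.empty(shape, dtype, device=device)` and fills it through `copyfrom`.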
Callers

nothing calls this directly

Calls (6)

_convert_bf16_to_f32 (function) · 0.85
copyfrom (method) · 0.80
read (method) · 0.65
join (method) · 0.45
empty (method) · 0.45
reshape (method) · 0.45
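The first callee listed, `_convert_bf16_to_f32`, backs the `f32-to-bf16` branch: float32 weights were stored as bfloat16 (the upper 16 bits of each float32) and must be widened back to float32 before `copyfrom`. A common numpy technique is to move each `uint16` into the high half of a `uint32` and reinterpret the bits as `float32`. The sketch below assumes that technique; it is not copied from tvmjs.py, and the names `bf16_bits_to_f32` and `f32_to_bf16_bits` (a truncating encoder used only to build a round trip) are hypothetical.

```python
import numpy as np


def f32_to_bf16_bits(x):
    """Hypothetical encoder: keep the upper 16 bits of each float32 (truncation, no rounding)."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)


def bf16_bits_to_f32(bits):
    """Widen bfloat16 bit patterns (stored as uint16) back to float32."""
    # Place the 16 stored bits in the high half of a 32-bit word; the low
    # 16 mantissa bits become zero. Then reinterpret those bits as float32.
    widened = np.left_shift(bits.astype(np.uint32), 16)
    return widened.view(np.float32)


if __name__ == "__main__":
    x = np.array([1.0, -2.0, 0.5, 3.140625], dtype=np.float32)
    print(bf16_bits_to_f32(f32_to_bf16_bits(x)))
```

Values whose significand fits in bfloat16's 8 bits of precision survive the round trip exactly; anything finer is lost at encode time, which is why the widening step needs no rounding logic of its own.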

Tested by

no test coverage detected
