MCPcopy Index your code
hub / github.com/apache/tvm / TensorCacheShardingManager

Class TensorCacheShardingManager

python/tvm/contrib/tvmjs.py:75–199  ·  view source on GitHub ↗

Internal helper to shard ndarrays.

Source from the content-addressed store, hash-verified

73
74
75class TensorCacheShardingManager:
76 """Internal helper to shard ndarrays."""
77
78 def __init__(
79 self,
80 cache_dir: str,
81 prefix: str,
82 shard_cap_nbytes: int,
83 initial_shard_records: Mapping[str, Any] | None = None,
84 ):
85 self.cache_dir = cache_dir
86 self.prefix = prefix
87 self.curr_records = []
88 self.curr_data = bytearray()
89 self.shard_records = []
90 self.shard_cap_nbytes = shard_cap_nbytes
91 self.counter = 0
92 self.name_to_record: Mapping[str, tuple[int, Mapping[str, Any]]] = {}
93 self.updated_shards: set[int] = set()
94
95 if initial_shard_records is not None:
96 self.shard_records = initial_shard_records
97 self.counter = len(initial_shard_records)
98 for idx, shard in enumerate(initial_shard_records):
99 for rec in shard["records"]:
100 self.name_to_record[rec["name"]] = (idx, rec)
101
102 def append_or_update(self, data, name, shape, dtype, encode_format, allow_update: bool = False):
103 """Commit a record to the manager.
104
105 Parameters
106 ----------
107 data: bytes
108 Raw bytes to be appended.
109
110 name: str
111 The name of the parameter
112
113 shape: tuple
114 The shape of the array
115
116 dtype: str
117 The dtype information
118
119 encode_format:
120 The encode format of the entry
121
122 allow_update: bool
123 If the record already exists, update the record. Otherwise, raise an error.
124 """
125 rec = {
126 "name": name,
127 "shape": shape,
128 "dtype": dtype,
129 "format": encode_format,
130 "nbytes": len(data),
131 }
132 if name in self.name_to_record:

Callers 1

dump_tensor_cacheFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…