Internal helper to shard ndarrays.
| 73 | |
| 74 | |
| 75 | class TensorCacheShardingManager: |
| 76 | """Internal helper to shard ndarrays.""" |
| 77 | |
| 78 | def __init__( |
| 79 | self, |
| 80 | cache_dir: str, |
| 81 | prefix: str, |
| 82 | shard_cap_nbytes: int, |
| 83 | initial_shard_records: Mapping[str, Any] | None = None, |
| 84 | ): |
| 85 | self.cache_dir = cache_dir |
| 86 | self.prefix = prefix |
| 87 | self.curr_records = [] |
| 88 | self.curr_data = bytearray() |
| 89 | self.shard_records = [] |
| 90 | self.shard_cap_nbytes = shard_cap_nbytes |
| 91 | self.counter = 0 |
| 92 | self.name_to_record: Mapping[str, tuple[int, Mapping[str, Any]]] = {} |
| 93 | self.updated_shards: set[int] = set() |
| 94 | |
| 95 | if initial_shard_records is not None: |
| 96 | self.shard_records = initial_shard_records |
| 97 | self.counter = len(initial_shard_records) |
| 98 | for idx, shard in enumerate(initial_shard_records): |
| 99 | for rec in shard["records"]: |
| 100 | self.name_to_record[rec["name"]] = (idx, rec) |
| 101 | |
| 102 | def append_or_update(self, data, name, shape, dtype, encode_format, allow_update: bool = False): |
| 103 | """Commit a record to the manager. |
| 104 | |
| 105 | Parameters |
| 106 | ---------- |
| 107 | data: bytes |
| 108 | Raw bytes to be appended. |
| 109 | |
| 110 | name: str |
| 111 | The name of the parameter |
| 112 | |
| 113 | shape: tuple |
| 114 | The shape of the array |
| 115 | |
| 116 | dtype: str |
| 117 | The dtype information |
| 118 | |
| 119 | encode_format: |
| 120 | The encode format of the entry |
| 121 | |
| 122 | allow_update: bool |
| 123 | If the record already exists, update the record. Otherwise, raise an error. |
| 124 | """ |
| 125 | rec = { |
| 126 | "name": name, |
| 127 | "shape": shape, |
| 128 | "dtype": dtype, |
| 129 | "format": encode_format, |
| 130 | "nbytes": len(data), |
| 131 | } |
| 132 | if name in self.name_to_record: |
no outgoing calls
no test coverage detected
searching dependent graphs…