Stacks the bytes for a weight group into a flat byte array. Args: group: A list of weight entries. Returns: A tuple: (group_bytes, total_bytes, group_bytes_writer) group_bytes: The stacked bytes for the group, as a BytesIO() stream. total_bytes: A number representing the total size of the byte buffer.
(group)
| 239 | return data.tobytes() |
| 240 | |
| 241 | def _stack_group_bytes(group): |
| 242 | """Stacks the bytes for a weight group into a flat byte array. |
| 243 | |
| 244 | Args: |
| 245 | group: A list of weight entries. |
| 246 | Returns: |
| 247 | A type: (group_bytes, total_bytes, weights_entries, group_bytes_writer) |
| 248 | group_bytes: The stacked bytes for the group, as a BytesIO() stream. |
| 249 | total_bytes: A number representing the total size of the byte buffer. |
| 250 | groups_bytes_writer: The io.BufferedWriter object. Returned so that |
| 251 | group_bytes does not get garbage collected and closed. |
| 252 | |
| 253 | """ |
| 254 | group_bytes = io.BytesIO() |
| 255 | group_bytes_writer = io.BufferedWriter(group_bytes) |
| 256 | total_bytes = 0 |
| 257 | |
| 258 | for entry in group: |
| 259 | _assert_valid_weight_entry(entry) |
| 260 | data = entry['data'] |
| 261 | |
| 262 | if data.dtype == object: |
| 263 | data_bytes = _serialize_string_array(data) |
| 264 | else: |
| 265 | data_bytes = _serialize_numeric_array(data) |
| 266 | group_bytes_writer.write(data_bytes) |
| 267 | total_bytes += len(data_bytes) |
| 268 | |
| 269 | group_bytes_writer.flush() |
| 270 | group_bytes.seek(0) |
| 271 | |
| 272 | # NOTE: We must return the bytes writer here, otherwise it goes out of scope |
| 273 | # and python closes the IO operation. |
| 274 | return (group_bytes, total_bytes, group_bytes_writer) |
| 275 | |
| 276 | |
| 277 | def _shard_group_bytes_to_disk( |
no test coverage detected
searching dependent graphs…