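The file-reading loop in `read_weights` joins each manifest group's shard files into one byte string per group. It reads the files in the order they appear under that group's `paths`, then passes the resulting list to `decode_weights`. The sketch below isolates just that step using only the standard library. The helper name `concat_group_shards` and the shard file names are hypothetical and chosen for illustration. The empty `weights` lists stand in for the per-weight specs a real manifest carries.

```python
import os
import tempfile


def concat_group_shards(weights_manifest, base_path):
    """Return one bytes object per manifest group, shards joined in listed order.

    Hypothetical helper mirroring the file-reading loop of read_weights;
    it does not decode anything.
    """
    if not isinstance(weights_manifest, list):
        raise ValueError('weights_manifest should be a `list`, but received %s'
                         % type(weights_manifest))
    data_buffers = []
    for group in weights_manifest:
        chunks = []
        for path in group['paths']:
            # Each path is resolved relative to base_path, as in read_weights.
            with open(os.path.join(base_path, path), 'rb') as f:
                chunks.append(f.read())
        data_buffers.append(b''.join(chunks))
    return data_buffers


with tempfile.TemporaryDirectory() as base_path:
    # Illustrative shard files: group 1 is split across two files.
    shards = {
        'group1-shard1of2.bin': b'\x01\x02',
        'group1-shard2of2.bin': b'\x03',
        'group2-shard1of1.bin': b'\x04\x05\x06',
    }
    for name, payload in shards.items():
        with open(os.path.join(base_path, name), 'wb') as f:
            f.write(payload)
    manifest = [
        {'paths': ['group1-shard1of2.bin', 'group1-shard2of2.bin'], 'weights': []},
        {'paths': ['group2-shard1of1.bin'], 'weights': []},
    ]
    buffers = concat_group_shards(manifest, base_path)

print(buffers)  # [b'\x01\x02\x03', b'\x04\x05\x06']
```

This sketch uses `b''.join` where the original writes through an `io.BufferedWriter` wrapped around an `io.BytesIO`. After the flush and rewind, both approaches yield the same concatenated bytes for each group.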
STRING_LENGTH_DTYPE = np.dtype('uint32').newbyteorder('<')

def read_weights(weights_manifest, base_path, flatten=False):
  """Load weight values according to a TensorFlow.js weights manifest.

  Args:
    weights_manifest: A TensorFlow.js-format weights manifest (a JSON array).
    base_path: Base path prefix for the weights files.
    flatten: Whether all the weight groups in the return value are to be
      flattened as a single weights group. Default: `False`.

  Returns:
    If `flatten` is `False`, a `list` of weight groups. Each group is an array
    of weight entries. Each entry is a dict holding a unique `'name'` and the
    corresponding numpy array under `'data'`, for example:
      entry = {
        'name': 'weight1',
        'data': np.array([1, 2, 3], 'float32')
      }

    The weight groups would then look like:
      weight_groups = [
        [group_0_entry1, group_0_entry2],
        [group_1_entry1, group_1_entry2],
      ]
    If `flatten` is `True`, returns a single weight group, i.e. one list
    containing the entries of every group.
  """
  if not isinstance(weights_manifest, list):
    raise ValueError(
        'weights_manifest should be a `list`, but received %s' %
        type(weights_manifest))

  data_buffers = []
  for group in weights_manifest:
    # Concatenate this group's shard files, in manifest order, into one buffer.
    buff = io.BytesIO()
    buff_writer = io.BufferedWriter(buff)
    for path in group['paths']:
      with open(os.path.join(base_path, path), 'rb') as f:
        buff_writer.write(f.read())
    buff_writer.flush()
    # Rewind so the whole concatenated buffer can be read back.
    buff_writer.seek(0)
    data_buffers.append(buff.read())
  return decode_weights(weights_manifest, data_buffers, flatten=flatten)
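`read_weights` delegates decoding to `decode_weights`, which is not shown here. To illustrate the return structure the docstring describes, including what `flatten=True` changes, here is a minimal sketch built around a hypothetical, simplified decoder named `decode_float32_groups`. It assumes three things. First, each manifest group carries a `weights` list of specs with `name`, `shape`, and `dtype` keys. Second, every weight is uncompressed little-endian `float32`. Third, the weights are stored back to back in spec order. It does not handle strings, other dtypes, or quantization, and it is not the library's implementation.

```python
import numpy as np


def decode_float32_groups(weights_manifest, data_buffers, flatten=False):
    """Hypothetical stand-in for decode_weights: float32 only, no quantization."""
    groups = []
    for group, buff in zip(weights_manifest, data_buffers):
        offset = 0
        entries = []
        for spec in group['weights']:
            if spec['dtype'] != 'float32':
                raise NotImplementedError('sketch handles float32 only, got %s'
                                          % spec['dtype'])
            shape = spec['shape']
            count = int(np.prod(shape))  # np.prod([]) is 1.0, so scalars work
            # Read `count` little-endian float32 values starting at `offset`.
            data = np.frombuffer(buff, dtype='<f4', count=count,
                                 offset=offset).reshape(shape)
            entries.append({'name': spec['name'], 'data': data})
            offset += count * 4  # 4 bytes per float32
        groups.append(entries)
    if flatten:
        # One group holding every entry, in manifest order.
        return [entry for entries in groups for entry in entries]
    return groups


manifest = [
    {'paths': ['a.bin'], 'weights': [
        {'name': 'w', 'shape': [2], 'dtype': 'float32'},
        {'name': 'b', 'shape': [], 'dtype': 'float32'}]},
    {'paths': ['b.bin'], 'weights': [
        {'name': 'k', 'shape': [1, 2], 'dtype': 'float32'}]},
]
buffers = [np.array([1, 2, 3], '<f4').tobytes(),
           np.array([4, 5], '<f4').tobytes()]

groups = decode_float32_groups(manifest, buffers)
flat = decode_float32_groups(manifest, buffers, flatten=True)
print([[e['name'] for e in g] for g in groups])  # [['w', 'b'], ['k']]
print([e['name'] for e in flat])  # ['w', 'b', 'k']
```

Without `flatten`, the output keeps one inner list per manifest group. With `flatten=True`, the same entries come back as a single list.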


def _deserialize_string_array(data_buffer, offset, shape):