decode_weights(weights_manifest, data_buffers, flatten=False)

Load weight values from buffer(s) according to a weights manifest.
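The manifest and the buffers pair up one-to-one: each manifest entry is a weight group, and each group has its own buffer. The sketch below builds a small manifest and mirrors the normalization and length check that `decode_weights` performs. The `paths` key, the weight names, and the shapes are illustrative assumptions. `normalize_buffers` is a hypothetical helper, not part of the library.

```python
import numpy as np

# A minimal TensorFlow.js-style weights manifest: a JSON array of groups.
# The path, names, and shapes below are made up for illustration.
weights_manifest = [
    {
        'paths': ['group1-shard1of1.bin'],
        'weights': [
            {'name': 'dense/kernel', 'shape': [2, 2], 'dtype': 'float32'},
            {'name': 'dense/bias', 'shape': [2], 'dtype': 'float32'},
        ],
    },
]


def normalize_buffers(weights_manifest, data_buffers):
    """Hypothetical helper mirroring the argument checks in decode_weights."""
    # A single buffer is treated as a list containing one buffer.
    if not isinstance(data_buffers, list):
        data_buffers = [data_buffers]
    # Exactly one buffer is expected per weight group.
    if len(weights_manifest) != len(data_buffers):
        raise ValueError(
            'Mismatch in the length of weights_manifest (%d) and the length of '
            'data buffers (%d)' % (len(weights_manifest), len(data_buffers)))
    return data_buffers


# 4 + 2 float32 values = 24 bytes, concatenated in manifest order.
buffer = np.arange(6, dtype='float32').tobytes()
buffers = normalize_buffers(weights_manifest, buffer)
print(len(buffers))  # 1
```

Passing two buffers for this one-group manifest raises `ValueError`, because the check compares group count to buffer count.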
```python
                       offset=offset).reshape(shape)

def decode_weights(weights_manifest, data_buffers, flatten=False):
  """Load weight values from buffer(s) according to a weights manifest.

  Args:
    weights_manifest: A TensorFlow.js-format weights manifest (a JSON array).
    data_buffers: A buffer or a `list` of buffers containing the weights values
      in binary format, concatenated in the order specified in
      `weights_manifest`. If a `list` of buffers, the length of the `list`
      must match the length of `weights_manifest`. A single buffer is
      interpreted as a `list` of one buffer and is valid only if the length of
      `weights_manifest` is `1`.
    flatten: Whether all the weight groups in the return value are to be
      flattened into a single weight group. Default: `False`.

  Returns:
    If `flatten` is `False`, a `list` of weight groups. Each group is an array
    of weight entries. Each entry is a dict holding a unique `name` and its
    `data` as a numpy array, for example:
        entry = {
          'name': 'weight1',
          'data': np.array([1, 2, 3], 'float32')
        }

        Weight groups would then look like:
        weight_groups = [
          [group_0_entry1, group_0_entry2],
          [group_1_entry1, group_1_entry2],
        ]
    If `flatten` is `True`, returns a single weight group.

  Raises:
    ValueError: if the lengths of `weights_manifest` and `data_buffers` do not
      match.
  """
  if not isinstance(data_buffers, list):
    data_buffers = [data_buffers]
  if len(weights_manifest) != len(data_buffers):
    raise ValueError(
        'Mismatch in the length of weights_manifest (%d) and the length of '
        'data buffers (%d)' % (len(weights_manifest), len(data_buffers)))

  out = []
  for group, data_buffer in zip(weights_manifest, data_buffers):
    offset = 0
    out_group = []

    for weight in group['weights']:
      quant_info = weight.get('quantization', None)
      name = weight['name']
      if weight['dtype'] == 'string':
        # String array.
        dtype = object
      elif quant_info:
        # Quantized array.
        dtype = np.dtype(quant_info['dtype'])
      else:
        # Regular numeric array.
        dtype = np.dtype(weight['dtype'])
```
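The listing stops once the dtype has been chosen. The sketch below shows one way decoding could continue for plain numeric weights. It assumes each group's weights are packed back to back with no padding and stored in the machine's native byte order. The hypothetical `decode_sketch` applies the same three-way dtype choice and walks an `offset` through each buffer with `np.frombuffer`. It is not the library's implementation: it rejects string weights and returns quantized weights in their stored dtype without dequantizing them.

```python
import numpy as np


def select_dtype(weight):
    """Same dtype choice as the loop in decode_weights."""
    quant_info = weight.get('quantization', None)
    if weight['dtype'] == 'string':
        return object
    if quant_info:
        # Quantized weights are stored in the quantized dtype (e.g. uint8).
        return np.dtype(quant_info['dtype'])
    return np.dtype(weight['dtype'])


def decode_sketch(weights_manifest, data_buffers, flatten=False):
    """Hypothetical reader for numeric weights packed back to back.

    Assumes no padding between weights. Does not decode string weights and
    does not dequantize quantized ones. The length check is omitted here.
    """
    if not isinstance(data_buffers, list):
        data_buffers = [data_buffers]
    out = []
    for group, data_buffer in zip(weights_manifest, data_buffers):
        offset = 0  # Byte position within this group's buffer.
        out_group = []
        for weight in group['weights']:
            dtype = select_dtype(weight)
            if dtype is object:
                raise NotImplementedError('string weights are not handled here')
            shape = weight['shape']
            count = int(np.prod(shape, dtype=np.int64))  # 1 for scalars.
            data = np.frombuffer(data_buffer, dtype, count=count,
                                 offset=offset).reshape(shape)
            offset += count * dtype.itemsize
            out_group.append({'name': weight['name'], 'data': data})
        out.append(out_group)
    # flatten=True concatenates all groups into a single list of entries.
    return [entry for g in out for entry in g] if flatten else out


manifest = [{'paths': ['group1-shard1of1.bin'],
             'weights': [{'name': 'w', 'shape': [2, 2], 'dtype': 'float32'},
                         {'name': 'b', 'shape': [2], 'dtype': 'float32'}]}]
buf = np.arange(6, dtype='float32').tobytes()
groups = decode_sketch(manifest, buf)
print(groups[0][1]['data'])  # [4. 5.]
```

The first weight consumes 4 × 4 = 16 bytes, so `b` starts at byte offset 16. With `flatten=True`, the same call returns one list containing both entries. This matches the docstring's description of a single weight group.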