Function decode_weights

Repository: github.com/tensorflow/tfjs
Source: tfjs-converter/python/tensorflowjs/read_weights.py, lines 126–202

Load weight values from one or more binary buffers according to a TensorFlow.js weights manifest.
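A minimal sketch of the byte layout this summary implies, assuming only plain numeric weights (no `string` dtype and no quantization). Within one group, weights sit back to back in a single buffer, so decoding walks the manifest entries and advances a byte offset by `itemsize * size` after each one. `decode_group`, the example manifest, and the shard file name are hypothetical illustrations, not part of the tensorflowjs API; the real function also handles string and quantized weights.

```python
import numpy as np

def decode_group(group, data_buffer):
    """Hypothetical helper: decode one manifest group from its buffer.

    Handles plain numeric dtypes only (no strings, no quantization).
    """
    offset = 0
    out_group = []
    for weight in group['weights']:
        dtype = np.dtype(weight['dtype'])
        shape = weight['shape']
        # Element count is the product of the shape (1 for a scalar shape []).
        size = int(np.prod(shape, dtype=np.int64))
        # Read exactly `size` elements starting at the current byte offset.
        value = np.frombuffer(data_buffer, dtype=dtype, count=size,
                              offset=offset).reshape(shape)
        # Skip past this weight's bytes to reach the next weight.
        offset += dtype.itemsize * size
        out_group.append({'name': weight['name'], 'data': value})
    return out_group

# Hypothetical one-group manifest with two float32 weights packed back to back.
manifest = [{
    'paths': ['group1-shard1of1.bin'],
    'weights': [
        {'name': 'dense/kernel', 'shape': [2, 2], 'dtype': 'float32'},
        {'name': 'dense/bias', 'shape': [2], 'dtype': 'float32'},
    ],
}]
kernel = np.array([[1, 2], [3, 4]], dtype='float32')
bias = np.array([0.5, -0.5], dtype='float32')
buffer = kernel.tobytes() + bias.tobytes()

decoded = decode_group(manifest[0], buffer)
print([(e['name'], e['data'].tolist()) for e in decoded])
# [('dense/kernel', [[1.0, 2.0], [3.0, 4.0]]), ('dense/bias', [0.5, -0.5])]
```

Using `np.frombuffer` with `count` and `offset` reads each weight as a view into the buffer rather than copying the bytes first.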

(weights_manifest, data_buffers, flatten=False)
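The three parameters interact as the docstring describes: a single buffer is wrapped into a one-element list, the number of buffers must equal the number of manifest groups, and `flatten` decides between a list of groups and one combined group. The sketch below restates that contract with two hypothetical helpers, `prepare_buffers` and `combine_groups`. The length check mirrors the first lines of `decode_weights`; merging groups by concatenating their entries is an assumption about what "a single weight group" means.

```python
def prepare_buffers(weights_manifest, data_buffers):
    """Hypothetical helper mirroring decode_weights' argument checks."""
    # A single buffer is treated as a list containing one buffer.
    if not isinstance(data_buffers, list):
        data_buffers = [data_buffers]
    # Exactly one buffer is required per manifest group.
    if len(weights_manifest) != len(data_buffers):
        raise ValueError(
            'Mismatch in the length of weights_manifest (%d) and the length of '
            'data buffers (%d)' % (len(weights_manifest), len(data_buffers)))
    return data_buffers

def combine_groups(groups, flatten=False):
    """Hypothetical helper: keep the nested groups, or merge them into one."""
    if flatten:
        # Assumed behavior: concatenate every group's entries in order.
        return [entry for group in groups for entry in group]
    return groups

manifest = [{'weights': []}]  # a manifest with one group
print(len(prepare_buffers(manifest, b'\x00' * 8)))  # 1

try:
    prepare_buffers(manifest, [b'', b''])
except ValueError as err:
    print(err)
# Mismatch in the length of weights_manifest (1) and the length of data buffers (2)

groups = [[{'name': 'a'}], [{'name': 'b'}, {'name': 'c'}]]
print([e['name'] for e in combine_groups(groups, flatten=True)])  # ['a', 'b', 'c']
```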

def decode_weights(weights_manifest, data_buffers, flatten=False):
  """Load weight values from buffer(s) according to a weights manifest.

  Args:
    weights_manifest: A TensorFlow.js-format weights manifest (a JSON array).
    data_buffers: A buffer or a `list` of buffers containing the weights values
      in binary format, concatenated in the order specified in
      `weights_manifest`. If a `list` of buffers, the length of the `list`
      must match the length of `weights_manifest`. A single buffer is
      interpreted as a `list` of one buffer and is valid only if the length of
      `weights_manifest` is `1`.
    flatten: Whether all the weight groups in the return value are to be
      flattened into a single weight group. Default: `False`.

  Returns:
    If `flatten` is `False`, a `list` of weight groups. Each group is an array
    of weight entries. Each entry is a dict holding a unique name and its numpy
    array, for example:
      entry = {
        'name': 'weight1',
        'data': np.array([1, 2, 3], 'float32')
      }

    Weight groups would then look like:
      weight_groups = [
        [group_0_entry1, group_0_entry2],
        [group_1_entry1, group_1_entry2],
      ]
    If `flatten` is `True`, returns a single weight group.

  Raises:
    ValueError: if the lengths of `weights_manifest` and `data_buffers` do not
      match.
  """
  if not isinstance(data_buffers, list):
    data_buffers = [data_buffers]
  if len(weights_manifest) != len(data_buffers):
    raise ValueError(
        'Mismatch in the length of weights_manifest (%d) and the length of '
        'data buffers (%d)' % (len(weights_manifest), len(data_buffers)))

  out = []
  for group, data_buffer in zip(weights_manifest, data_buffers):
    offset = 0
    out_group = []

    for weight in group['weights']:
      quant_info = weight.get('quantization', None)
      name = weight['name']
      if weight['dtype'] == 'string':
        # String array.
        dtype = object
      elif quant_info:
        # Quantized array.
        dtype = np.dtype(quant_info['dtype'])
      else:
        # Regular numeric array.
        dtype = np.dtype(weight['dtype'])
      # ... (source lines 184–202 are not shown in this excerpt)
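The branch at the end of the excerpt chooses the dtype used to read each weight's raw bytes: `object` for string arrays, the quantization dtype when a `quantization` entry is present, and the declared dtype otherwise. A quantized `float32` weight is therefore read with its quantized dtype (for example `uint8`), not `float32`. The sketch below isolates that branch in a hypothetical `storage_dtype` helper; the weight dicts are made-up examples, and what the real function does with the value afterward is not shown here.

```python
import numpy as np

def storage_dtype(weight):
    """Hypothetical helper: the dtype used to read a weight's raw bytes."""
    quant_info = weight.get('quantization', None)
    if weight['dtype'] == 'string':
        # String arrays are held as Python objects, not a fixed-width dtype.
        return object
    if quant_info:
        # Quantized weights are stored in the quantization dtype.
        return np.dtype(quant_info['dtype'])
    # Regular numeric weights are stored in their declared dtype.
    return np.dtype(weight['dtype'])

print(storage_dtype({'name': 'w', 'dtype': 'float32'}))  # float32
print(storage_dtype({'name': 'q', 'dtype': 'float32',
                     'quantization': {'dtype': 'uint8'}}))  # uint8
print(storage_dtype({'name': 's', 'dtype': 'string'}))  # <class 'object'>
```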

Callers (1)

read_weights · function · 0.85

Calls (7; 5 shown)

ValueError · class · 0.85
zip · function · 0.85
NotImplementedError · class · 0.85
append · method · 0.80
get · method · 0.45

Tested by

no test coverage detected
