Repository: github.com/tensorflow/tfjs

Function read_weights

tfjs-converter/python/tensorflowjs/read_weights.py (lines 35–75)

Load weight values according to a TensorFlow.js weights manifest. Args: weights_manifest: a TensorFlow.js-format weights manifest (a JSON array). base_path: base path prefix for the weights files. flatten: whether all the weight groups in the return value are to be flattened into a single weights group (default: `False`).

(weights_manifest, base_path, flatten=False)
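The following minimal sketch makes the expected inputs concrete. It writes a manifest and its shard files to a temporary directory. `read_weights` itself reads only each group's `paths` and joins them onto `base_path`. The `weights` entries (`name`, `shape`, `dtype`) are assumed here to follow the usual TensorFlow.js manifest layout, and `decode_weights` consumes them. The helper `write_example_model` and the shard file names are hypothetical.

```python
import json
import os
import struct
import tempfile


def write_example_model(base_path):
    """Hypothetical helper: write one two-shard weight group, return its manifest."""
    # Three little-endian float32 values (12 bytes), split 8 + 4 bytes
    # across two shard files so the tensor straddles a shard boundary.
    payload = struct.pack('<3f', 1.0, 2.0, 3.0)
    shards = {
        'group1-shard1of2.bin': payload[:8],
        'group1-shard2of2.bin': payload[8:],
    }
    for name, chunk in shards.items():
        with open(os.path.join(base_path, name), 'wb') as f:
            f.write(chunk)

    # The manifest is a JSON array of groups; shard order in `paths` matters.
    manifest_json = json.dumps([{
        'paths': list(shards),
        'weights': [{'name': 'weight1', 'shape': [3], 'dtype': 'float32'}],
    }])
    # read_weights wants the parsed `list`, not the JSON string
    # (anything that is not a list raises ValueError).
    return json.loads(manifest_json)


base_path = tempfile.mkdtemp()
weights_manifest = write_example_model(base_path)
print(weights_manifest[0]['paths'])
# With the tensorflowjs package installed, the call would be:
#   from tensorflowjs.read_weights import read_weights
#   weight_groups = read_weights(weights_manifest, base_path)
```

Splitting one tensor across two shards is deliberate. `read_weights` concatenates a group's shards before any decoding happens, so a single entry may straddle a shard boundary.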


STRING_LENGTH_DTYPE = np.dtype('uint32').newbyteorder('<')

def read_weights(weights_manifest, base_path, flatten=False):
  """Load weight values according to a TensorFlow.js weights manifest.

  Args:
    weights_manifest: A TensorFlow.js-format weights manifest (a JSON array).
    base_path: Base path prefix for the weights files.
    flatten: Whether all the weight groups in the return value are to be
      flattened as a single weights group. Default: `False`.

  Returns:
    If `flatten` is `False`, a `list` of weight groups. Each group is an array
    of weight entries. Each entry is a dict that maps a unique name to a numpy
    array, for example:
        entry = {
          'name': 'weight1',
          'data': np.array([1, 2, 3], 'float32')
        }

    Weights groups would then look like:
        weight_groups = [
          [group_0_entry1, group_0_entry2],
          [group_1_entry1, group_1_entry2],
        ]
    If `flatten` is `True`, returns a single weight group.
  """
  if not isinstance(weights_manifest, list):
    raise ValueError(
        'weights_manifest should be a `list`, but received %s' %
        type(weights_manifest))

  data_buffers = []
  for group in weights_manifest:
    buff = io.BytesIO()
    buff_writer = io.BufferedWriter(buff)
    for path in group['paths']:
      with open(os.path.join(base_path, path), 'rb') as f:
        buff_writer.write(f.read())
    buff_writer.flush()
    buff_writer.seek(0)
    data_buffers.append(buff.read())
  return decode_weights(weights_manifest, data_buffers, flatten=flatten)


def _deserialize_string_array(data_buffer, offset, shape):
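The per-group logic above can be approximated with the standard library alone. In this sketch, shard files are read in manifest order and joined into one buffer. The listing's `io.BufferedWriter`/`io.BytesIO` pair produces the same bytes as `b''.join`. Entries are then decoded back to back from that buffer, which is roughly the job handed to `decode_weights`.

The sketch makes two simplifying assumptions:
* every entry is little-endian `float32`;
* it returns tuples instead of numpy arrays.

The real `decode_weights` handles more than this; the `STRING_LENGTH_DTYPE` constant and `_deserialize_string_array` in the listing point to string support. `read_weights_sketch` is a hypothetical name, not part of tensorflowjs.

```python
import math
import os
import struct
import tempfile


def read_weights_sketch(weights_manifest, base_path, flatten=False):
    """Hypothetical, float32-only approximation of read_weights."""
    if not isinstance(weights_manifest, list):
        raise ValueError('weights_manifest should be a `list`, but received %s'
                         % type(weights_manifest))
    groups = []
    for group in weights_manifest:
        # Concatenate the group's shards in manifest order.
        chunks = []
        for path in group['paths']:
            with open(os.path.join(base_path, path), 'rb') as f:
                chunks.append(f.read())
        buffer = b''.join(chunks)

        # Decode entries sequentially; float32 assumed, i.e. 4 bytes/element.
        offset = 0
        entries = []
        for spec in group['weights']:
            count = math.prod(spec['shape'])  # a scalar shape [] gives 1
            values = struct.unpack_from('<%df' % count, buffer, offset)
            offset += 4 * count
            entries.append({'name': spec['name'], 'data': values})
        groups.append(entries)

    # flatten=True merges every group's entries into a single group.
    if flatten:
        return [entry for entries in groups for entry in entries]
    return groups


# Usage: two single-shard groups, a 2x2 kernel and a length-2 bias.
base_path = tempfile.mkdtemp()
with open(os.path.join(base_path, 'a.bin'), 'wb') as f:
    f.write(struct.pack('<4f', 1.0, 2.0, 3.0, 4.0))
with open(os.path.join(base_path, 'b.bin'), 'wb') as f:
    f.write(struct.pack('<2f', 0.5, -0.5))
manifest = [
    {'paths': ['a.bin'],
     'weights': [{'name': 'kernel', 'shape': [2, 2], 'dtype': 'float32'}]},
    {'paths': ['b.bin'],
     'weights': [{'name': 'bias', 'shape': [2], 'dtype': 'float32'}]},
]
nested = read_weights_sketch(manifest, base_path)                # two groups
flat = read_weights_sketch(manifest, base_path, flatten=True)   # one group
```

Keeping file I/O separate from decoding, as the listing does by passing `data_buffers` to `decode_weights`, plausibly lets the same decoder serve buffers that did not come from local files.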

Callers

No direct callers detected.

Calls (7)

ValueError · Class · 0.85
decode_weights · Function · 0.85
join · Method · 0.80
flush · Method · 0.80
append · Method · 0.80
write · Method · 0.65
read · Method · 0.65

Tested by

no test coverage detected
