MCPcopy Index your code
hub / github.com/tensorflow/tfjs / parseWeights

Function parseWeights

tfjs-layers/src/engine/container.ts:666–691  ·  view source on GitHub ↗
(weights: NamedTensorMap)

Source from the content-addressed store, hash-verified

664 }
665
/**
 * Rewrites the keys of `weights` in place so that they no longer contain
 * the bookkeeping path segments used by the Keras v3 saving format.
 *
 * For Keras v3, weight names are saved based on the folder structure, e.g.
 *   _backbone/_layer_checkpoint_dependencies/transformer/_self../
 *       _output_dense/vars/0
 * For every key, a single leading underscore is stripped from each
 * `/`-separated segment, and the segments `vars` and
 * `layer_checkpoint_dependencies` are dropped, so that only the layer names
 * and the weight identifier remain. This helps map each saved weight to the
 * actual layer name when loading.
 *
 * @param weights Map from weight name to tensor. It is mutated: entries
 *     whose normalized name differs from the original are re-inserted under
 *     the normalized name and the original entry is deleted. If two entries
 *     normalize to the same name, the one processed later overwrites the
 *     earlier one.
 */
protected parseWeights(weights: NamedTensorMap) {
  // Path segments (after the leading underscore has been removed) that carry
  // no layer information and are discarded from the weight name.
  const ignoredSegments = ['vars', 'layer_checkpoint_dependencies'];

  // Iterate with `for...of` so that `key` is the weight name itself.
  // `for...in` over the array returned by Object.keys() would yield the
  // array indices ("0", "1", ...) instead and no key would ever be renamed.
  // Object.keys() returns a fresh array, so adding and deleting entries on
  // `weights` inside the loop does not disturb the iteration.
  for (const key of Object.keys(weights)) {
    const newKey = key.split('/')
        .map(part => part.startsWith('_') ? part.slice(1) : part)
        .filter(part => !ignoredSegments.includes(part))
        .join('/');
    if (newKey !== key) {
      weights[newKey] = weights[key];
      delete weights[key];
    }
  }
}
692
693 /**
694 * Util shared between different serialization methods.

Callers

nothing calls this directly

Calls 3

joinMethod · 0.80
splitMethod · 0.65
sliceMethod · 0.65

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…