| 664 | } |
| 665 | |
/**
 * Normalizes weight names saved by Keras v3 so they can be matched against
 * the actual layer names when loading.
 *
 * Keras v3 derives weight names from the on-disk folder structure, e.g.
 *   `_backbone/_layer_checkpoint_dependencies/transformer/_self../_output_dense/vars/0`
 * For every `/`-separated segment, a single leading underscore is removed.
 * Then the bookkeeping segments `vars` and `layer_checkpoint_dependencies`
 * are dropped, leaving only the layer path and the weight index.
 *
 * `weights` is mutated in place. When a key changes, its value is stored
 * under the normalized key and the original key is deleted.
 *
 * @param weights Map from saved weight name to tensor; modified in place.
 */
protected parseWeights(weights: NamedTensorMap) {
  const discardedSegments = ['vars', 'layer_checkpoint_dependencies'];
  // `for..of` iterates the key strings themselves. The previous `for..in`
  // over the Object.keys array iterated its indices ("0", "1", ...) instead,
  // so no real weight name was ever renamed. Object.keys returns a snapshot,
  // so adding and deleting entries during the loop is safe.
  for (const key of Object.keys(weights)) {
    const newKey = key
      .split('/')
      // Strip one leading underscore per segment (`_backbone` -> `backbone`).
      .map(segment => (segment.startsWith('_') ? segment.slice(1) : segment))
      .filter(segment => !discardedSegments.includes(segment))
      .join('/');
    if (newKey !== key) {
      // NOTE(review): if `newKey` already exists in `weights`, that entry is
      // overwritten here. Confirm saved names cannot collide after
      // normalization.
      weights[newKey] = weights[key];
      delete weights[key];
    }
  }
}
| 692 | |
| 693 | /** |
| 694 | * Util shared between different serialization methods. |