* Save the configuration and/or weights of the LayersModel. An `IOHandler` is an object that has a `save` method of the proper signature defined. The `save` method manages the storing or transmission of serialized data ("artifacts") that represent the model's topology and weights.
(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig)
| 2110 | * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} |
| 2111 | */ |
async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):
    Promise<io.SaveResult> {
  // Resolve the destination. A URL string must map to exactly one registered
  // save handler; an IOHandler object is used as given.
  let handler: io.IOHandler;
  if (typeof handlerOrURL === 'string') {
    const candidates = io.getSaveHandlers(handlerOrURL);
    if (candidates.length === 0) {
      throw new ValueError(
          `Cannot find any save handlers for URL '${handlerOrURL}'`);
    }
    if (candidates.length > 1) {
      throw new ValueError(
          `Found more than one (${candidates.length}) save handlers for ` +
          `URL '${handlerOrURL}'`);
    }
    handler = candidates[0];
  } else {
    handler = handlerOrURL;
  }

  if (handler.save == null) {
    throw new ValueError(
        'LayersModel.save() cannot proceed because the IOHandler ' +
        'provided does not have the `save` attribute defined.');
  }

  // Encode the model's named weights (selection governed by `config`).
  const encoded = await io.encodeWeights(this.getNamedWeights(config));

  // Topology is requested as a non-string value (returnString = false); the
  // first toJSON argument is unused here.
  const modelArtifacts: io.ModelArtifacts = {
    modelTopology: this.toJSON(null, false),
    format: LAYERS_MODEL_FORMAT_NAME,
    generatedBy: `TensorFlow.js tfjs-layers v${version}`,
    convertedBy: null,
  };

  // Optimizer state is only written when explicitly requested AND an
  // optimizer is present; its weights are appended after the model weights.
  const wantOptimizer = config != null && config.includeOptimizer;
  if (wantOptimizer && this.optimizer != null) {
    modelArtifacts.trainingConfig = this.getTrainingConfig();
    const optimizerEncoded =
        await io.encodeWeights(await this.optimizer.getWeights(), 'optimizer');
    encoded.specs.push(...optimizerEncoded.specs);
    encoded.data =
        io.concatenateArrayBuffers([encoded.data, optimizerEncoded.data]);
  }

  if (this.userDefinedMetadata != null) {
    // Check the user-defined metadata, including its serialized size.
    checkUserDefinedMetadata(this.userDefinedMetadata, this.name, true);
    modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
  }

  modelArtifacts.weightData = encoded.data;
  modelArtifacts.weightSpecs = encoded.specs;
  return handler.save(modelArtifacts);
}
| 2167 | |
| 2168 | /** |
| 2169 | * Set user-defined metadata. |
no test coverage detected