(modelArtifacts: tf.io.ModelArtifacts)
| 78 | } |
| 79 | |
/**
 * Writes the given model artifacts to the directory at `this.path` as two
 * files: a JSON file (`this.MODEL_JSON_FILENAME`) holding the topology,
 * weights manifest and metadata, and a binary file
 * (`this.WEIGHTS_BINARY_FILENAME`) holding the raw weight data.
 *
 * All inputs are validated before anything is written, so a rejected call
 * never leaves a model JSON file that references a missing weights file.
 *
 * @param modelArtifacts The topology, weight specs, weight data and optional
 *     metadata to persist.
 * @returns A `SaveResult` whose `modelArtifactsInfo` is computed by
 *     `tf.io.getModelArtifactsInfoForJSON`.
 * @throws Error if this handler was constructed with multiple paths, or if
 *     `modelArtifacts.modelTopology` is an `ArrayBuffer` (binary topology is
 *     not supported).
 * @throws TypeError if `modelArtifacts.weightData` is missing.
 */
async save(modelArtifacts: tf.io.ModelArtifacts): Promise<tf.io.SaveResult> {
  if (Array.isArray(this.path)) {
    throw new Error('Cannot perform saving to multiple paths.');
  }

  // NOTE(review): presumably creates the target directory when absent and
  // rejects when the path is unusable; confirm against its definition.
  await this.createOrVerifyDirectory();

  if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
    // TODO(cais, nkreeger): Implement this. See
    //   https://github.com/tensorflow/tfjs/issues/343
    throw new Error(
        'NodeFileSystem.save() does not support saving model topology ' +
        'in binary format yet.');
  }

  // Materialize the weights buffer before touching the disk. Previously a
  // missing `weightData` only failed inside Buffer.from() after the model
  // JSON had already been written, leaving a manifest that pointed at a
  // weights file that was never created.
  if (modelArtifacts.weightData == null) {
    throw new TypeError(
        'NodeFileSystem.save() requires modelArtifacts.weightData to be ' +
        'defined.');
  }
  const weightsBuffer = Buffer.from(modelArtifacts.weightData);

  const weightsBinPath = join(this.path, this.WEIGHTS_BINARY_FILENAME);
  const weightsManifest = [{
    paths: [this.WEIGHTS_BINARY_FILENAME],
    weights: modelArtifacts.weightSpecs
  }];
  const modelJSON: tf.io.ModelJSON = {
    modelTopology: modelArtifacts.modelTopology,
    weightsManifest,
    format: modelArtifacts.format,
    generatedBy: modelArtifacts.generatedBy,
    convertedBy: modelArtifacts.convertedBy
  };
  // Optional sections are copied only when present so that absent ones do
  // not show up as explicit keys on the serialized object.
  if (modelArtifacts.trainingConfig != null) {
    modelJSON.trainingConfig = modelArtifacts.trainingConfig;
  }
  if (modelArtifacts.signature != null) {
    modelJSON.signature = modelArtifacts.signature;
  }
  if (modelArtifacts.userDefinedMetadata != null) {
    modelJSON.userDefinedMetadata = modelArtifacts.userDefinedMetadata;
  }

  const modelJSONPath = join(this.path, this.MODEL_JSON_FILENAME);
  await writeFile(modelJSONPath, JSON.stringify(modelJSON), 'utf8');
  // No encoding argument: Node ignores it when the data is a Buffer.
  await writeFile(weightsBinPath, weightsBuffer);

  return {
    // TODO(cais): Use explicit tf.io.ModelArtifactsInfo type below once it
    //   is available.
    modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(modelArtifacts),
  };
}
| 128 | async load(): Promise<tf.io.ModelArtifacts> { |
| 129 | return Array.isArray(this.path) ? this.loadBinaryModel() : |
| 130 | this.loadJSONModel(); |
no test coverage detected