MCPcopy Index your code
hub / github.com/tensorflow/tfjs / save

Method save

tfjs-node/src/io/file_system.ts:80–127  ·  view source on GitHub ↗
(modelArtifacts: tf.io.ModelArtifacts)

Source from the content-addressed store, hash-verified

78 }
79
/**
 * Saves the given model artifacts as two files under `this.path`:
 *
 *  - `this.MODEL_JSON_FILENAME`: the JSON-serialized topology, a single-entry
 *    weights manifest pointing at the weights file, and whichever optional
 *    metadata fields (`trainingConfig`, `signature`, `userDefinedMetadata`)
 *    are present on `modelArtifacts`.
 *  - `this.WEIGHTS_BINARY_FILENAME`: the raw bytes of
 *    `modelArtifacts.weightData`.
 *
 * `this.createOrVerifyDirectory()` is awaited before anything is written.
 *
 * @param modelArtifacts The artifacts to save. `modelTopology` must not be an
 *     `ArrayBuffer`. If `weightData` is null or undefined, an empty weights
 *     file is written.
 * @returns A `SaveResult` whose `modelArtifactsInfo` is computed by
 *     `tf.io.getModelArtifactsInfoForJSON(modelArtifacts)`.
 * @throws Error if `this.path` is an array of paths, or if the model topology
 *     is in binary (`ArrayBuffer`) format.
 */
async save(modelArtifacts: tf.io.ModelArtifacts): Promise<tf.io.SaveResult> {
  if (Array.isArray(this.path)) {
    throw new Error('Cannot perform saving to multiple paths.');
  }

  await this.createOrVerifyDirectory();

  if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
    // TODO(cais, nkreeger): Implement this. See
    //   https://github.com/tensorflow/tfjs/issues/343
    throw new Error(
        'NodeFileSystem.save() does not support saving model topology ' +
        'in binary format yet.');
  }

  // Build the weight bytes before touching the disk. Buffer.from(undefined)
  // throws a TypeError, and doing the conversion up front means a bad input
  // cannot leave a model JSON on disk without its weights file. Missing
  // weight data is treated as an empty weights file.
  const weightBytes = modelArtifacts.weightData == null ?
      Buffer.alloc(0) :
      Buffer.from(modelArtifacts.weightData);

  const weightsBinPath = join(this.path, this.WEIGHTS_BINARY_FILENAME);
  const weightsManifest = [{
    paths: [this.WEIGHTS_BINARY_FILENAME],
    weights: modelArtifacts.weightSpecs
  }];
  const modelJSON: tf.io.ModelJSON = {
    modelTopology: modelArtifacts.modelTopology,
    weightsManifest,
    format: modelArtifacts.format,
    generatedBy: modelArtifacts.generatedBy,
    convertedBy: modelArtifacts.convertedBy
  };
  // Optional sections are copied only when present so that absent ones do
  // not appear as keys on the JSON object.
  if (modelArtifacts.trainingConfig != null) {
    modelJSON.trainingConfig = modelArtifacts.trainingConfig;
  }
  if (modelArtifacts.signature != null) {
    modelJSON.signature = modelArtifacts.signature;
  }
  if (modelArtifacts.userDefinedMetadata != null) {
    modelJSON.userDefinedMetadata = modelArtifacts.userDefinedMetadata;
  }

  const modelJSONPath = join(this.path, this.MODEL_JSON_FILENAME);
  await writeFile(modelJSONPath, JSON.stringify(modelJSON), 'utf8');
  await writeFile(weightsBinPath, weightBytes, 'binary');

  return {
    // TODO(cais): Use explicit tf.io.ModelArtifactsInfo type below once it
    // is available.
    modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(modelArtifacts),
  };
}
128 async load(): Promise<tf.io.ModelArtifacts> {
129 return Array.isArray(this.path) ? this.loadBinaryModel() :
130 this.loadJSONModel();

Callers 1

Calls 1

Tested by

no test coverage detected