MCPcopy
hub / github.com/tensorflow/tfjs / setModelInputFromTensor

Method setModelInputFromTensor

tfjs-tflite/src/tflite_model.ts:189–265  ·  view source on GitHub ↗
(
      modelInput: TFLiteWebModelRunnerTensorInfo, tensor: Tensor)

Source from the content-addressed store, hash-verified

187 }
188
189 private setModelInputFromTensor(
190 modelInput: TFLiteWebModelRunnerTensorInfo, tensor: Tensor) {
191 // String and complex tensors are not supported.
192 if (tensor.dtype === 'string' || tensor.dtype === 'complex64') {
193 throw new Error(`Data type '${tensor.dtype}' not supported.`);
194 }
195
196 // Check shape.
197 //
198 // At this point, we've already checked that input tensors and model inputs
199 // have the same size.
200 const modelInputShape = modelInput.shape.split(',').map(dim => Number(dim));
201 if (!tensor.shape.every(
202 (dim, index) => modelInputShape[index] === -1 ||
203 modelInputShape[index] === dim)) {
204 throw new Error(`Input tensor shape mismatch: expect '${
205 modelInput.shape}', got '${tensor.shape.join(',')}'.`);
206 }
207
208 // Check types.
209 switch (modelInput.dataType) {
210 // All 'bool' and 'int' tflite types accpet 'bool' or 'int32' tfjs types.
211 // Will throw error for 'float32' tfjs type.
212 case 'bool':
213 case 'int8':
214 case 'uint8':
215 case 'int16':
216 case 'uint32':
217 case 'int32':
218 if (tensor.dtype === 'float32') {
219 throw this.getDataTypeMismatchError(
220 modelInput.dataType, tensor.dtype);
221 } else if (modelInput.dataType !== tensor.dtype) {
222 console.warn(`WARNING: converting '${tensor.dtype}' to '${
223 modelInput.dataType}'`);
224 }
225 break;
226 // All 'float' tflite types accept all tfjs types.
227 case 'float32':
228 case 'float64':
229 if (modelInput.dataType !== tensor.dtype) {
230 console.warn(`WARNING: converting '${tensor.dtype}' to '${
231 modelInput.dataType}'`);
232 }
233 break;
234 default:
235 break;
236 }
237
238 const modelInputBuffer = modelInput.data();
239 switch (modelInput.dataType) {
240 case 'int8':
241 modelInputBuffer.set(Int8Array.from(tensor.dataSync()));
242 break;
243 case 'uint8':
244 case 'bool':
245 modelInputBuffer.set(Uint8Array.from(tensor.dataSync()));
246 break;

Callers 1

predict (method) · 0.95

Calls 6

join (method) · 0.80
split (method) · 0.65
data (method) · 0.65
dataSync (method) · 0.65
set (method) · 0.45

Tested by

no test coverage detected