| 187 | } |
| 188 | |
| 189 | private setModelInputFromTensor( |
| 190 | modelInput: TFLiteWebModelRunnerTensorInfo, tensor: Tensor) { |
| 191 | // String and complex tensors are not supported. |
| 192 | if (tensor.dtype === 'string' || tensor.dtype === 'complex64') { |
| 193 | throw new Error(`Data type '${tensor.dtype}' not supported.`); |
| 194 | } |
| 195 | |
| 196 | // Check shape. |
| 197 | // |
| 198 | // At this point, we've already checked that input tensors and model inputs |
| 199 | // have the same size. |
| 200 | const modelInputShape = modelInput.shape.split(',').map(dim => Number(dim)); |
| 201 | if (!tensor.shape.every( |
| 202 | (dim, index) => modelInputShape[index] === -1 || |
| 203 | modelInputShape[index] === dim)) { |
| 204 | throw new Error(`Input tensor shape mismatch: expect '${ |
| 205 | modelInput.shape}', got '${tensor.shape.join(',')}'.`); |
| 206 | } |
| 207 | |
| 208 | // Check types. |
| 209 | switch (modelInput.dataType) { |
| 210 | // All 'bool' and 'int' tflite types accept 'bool' or 'int32' tfjs types. |
| 211 | // Will throw error for 'float32' tfjs type. |
| 212 | case 'bool': |
| 213 | case 'int8': |
| 214 | case 'uint8': |
| 215 | case 'int16': |
| 216 | case 'uint32': |
| 217 | case 'int32': |
| 218 | if (tensor.dtype === 'float32') { |
| 219 | throw this.getDataTypeMismatchError( |
| 220 | modelInput.dataType, tensor.dtype); |
| 221 | } else if (modelInput.dataType !== tensor.dtype) { |
| 222 | console.warn(`WARNING: converting '${tensor.dtype}' to '${ |
| 223 | modelInput.dataType}'`); |
| 224 | } |
| 225 | break; |
| 226 | // All 'float' tflite types accept all tfjs types. |
| 227 | case 'float32': |
| 228 | case 'float64': |
| 229 | if (modelInput.dataType !== tensor.dtype) { |
| 230 | console.warn(`WARNING: converting '${tensor.dtype}' to '${ |
| 231 | modelInput.dataType}'`); |
| 232 | } |
| 233 | break; |
| 234 | default: |
| 235 | break; |
| 236 | } |
| 237 | |
| 238 | const modelInputBuffer = modelInput.data(); |
| 239 | switch (modelInput.dataType) { |
| 240 | case 'int8': |
| 241 | modelInputBuffer.set(Int8Array.from(tensor.dataSync())); |
| 242 | break; |
| 243 | case 'uint8': |
| 244 | case 'bool': |
| 245 | modelInputBuffer.set(Uint8Array.from(tensor.dataSync())); |
| 246 | break; |