Downloads the values of any `tf.Tensor`s found within `tensorContainer`. Returns a promise of `TypedArray` or `TypedArray[]` that resolves when the computation has finished. The values are asynchronously downloaded in parallel (on the 'webgl' backend they are read synchronously instead). @param tensorContainer The container of tensors to be downloaded.
(tensorContainer)
/**
 * Downloads the values of any `tf.Tensor`s found within `tensorContainer`.
 * Returns a promise of `TypedArray` or `TypedArray[]` that resolves when the
 * computation has finished.
 *
 * The values are asynchronously downloaded in parallel. When the active
 * backend is 'webgl', each tensor is instead read synchronously with
 * `dataSync()`.
 *
 * Supported containers:
 *  - a single `tf.Tensor`: resolves to its values.
 *  - an array: resolves to an array in which every `tf.Tensor` element is
 *    replaced by its values; other elements are passed through unchanged.
 *  - any other non-null object: resolves to an ARRAY holding the values of
 *    the object's own enumerable properties, in property order (keys are not
 *    preserved). Tensors are replaced by their values; other values are passed
 *    through unchanged.
 * Any other input (`null`, `undefined`, primitives) resolves to `undefined`.
 *
 * Only top-level entries are inspected: tensors nested deeper inside the
 * container are returned as-is, not downloaded.
 *
 * @param tensorContainer The container of tensors to be downloaded.
 * @returns {Promise<TypedArray|Array|undefined>} The downloaded values.
 */
async function downloadValuesFromTensorContainer(tensorContainer) {
  // Decide the read mode once so every tensor in the container is read the
  // same way.
  const readSync = tf.getBackend() === 'webgl';

  // Returns a tensor's values (or a promise of them). Non-tensors pass through
  // unchanged.
  const readValue = (item) => {
    if (!(item instanceof tf.Tensor)) {
      return item;
    }
    return readSync ? item.dataSync() : item.data();
  };

  if (tensorContainer instanceof tf.Tensor) {
    return readValue(tensorContainer);
  }

  if (Array.isArray(tensorContainer)) {
    // Start all downloads first, then wait for them together.
    return Promise.all(tensorContainer.map(readValue));
  }

  if (tensorContainer != null && typeof tensorContainer === 'object') {
    // Object.values visits only own enumerable properties, in the same order
    // for...in would. Unlike for...in, it skips inherited properties.
    return Promise.all(Object.values(tensorContainer).map(readValue));
  }

  return undefined;
}
| 396 | |
| 397 | /** |
| 398 | * Executes the predict function for `model` (`model.predict` for |
no test coverage detected
searching dependent graphs…