* Executes `predict()` and returns a promise that resolves with information * about the memory usage: * - `newBytes`: the number of new bytes allocated. * - `newTensors`: the number of new tensors created. * - `peakBytes`: the peak number of bytes allocated. * - `kernels`: an array of kernel in
(predict, isTflite = false, numProfiles = 1)
| 463 | * @param numProfiles The number of rounds for `predict` to execute and profile. |
| 464 | */ |
async function profileInference(predict, isTflite = false, numProfiles = 1) {
  if (typeof predict !== 'function') {
    throw new Error(
        'The first parameter should be a function, while ' +
        `a(n) ${typeof predict} is found.`);
  }
  // With no profiling round there is nothing to average; fail with a clear
  // message instead of a TypeError when reading the first round below.
  if (!(numProfiles > 0)) {
    throw new Error(
        'numProfiles should be a positive number, while ' +
        `${numProfiles} is found.`);
  }

  // `kernelInfo` always refers to the most recent round; `kernelInfos` keeps
  // one entry per round so that kernel times can be averaged afterwards.
  let kernelInfo = {};
  const kernelInfos = [];
  if (isTflite) {
    for (let i = 0; i < numProfiles; i++) {
      await predict();
      const profileItems = await tfliteModel.getProfilingResults();
      // Build a fresh object for every round. Reusing a single object would
      // make every entry of `kernelInfos` alias the last round, so the
      // "average" would silently equal the last round's timings.
      kernelInfo = {
        kernels: profileItems.map(item => {
          return {
            name: item.nodeType,
            kernelTimeMs: item.nodeExecMs,
            // TODO: Shapes are not supported yet.
            inputShapes: [],
            outputShapes: [],
          };
        }),
      };
      kernelInfos.push(kernelInfo);
    }
  } else {
    for (let i = 0; i < numProfiles; i++) {
      kernelInfo = await tf.profile(async () => {
        const res = await predict();
        await downloadValuesFromTensorContainer(res);
        tf.dispose(res);
      });
      kernelInfos.push(kernelInfo);
    }
  }

  // Average each kernel's time across rounds, matching kernels by position,
  // and store the result in the last round's info object, which is returned.
  // Every round is expected to report the same kernels in the same order.
  for (let i = 0; i < kernelInfos[0].kernels.length; i++) {
    let totalTimeMs = 0;
    for (let j = 0; j < kernelInfos.length; j++) {
      totalTimeMs += kernelInfos[j].kernels[i].kernelTimeMs;
    }
    kernelInfo.kernels[i].kernelTimeMs = totalTimeMs / kernelInfos.length;
  }

  // Slowest kernels first. `sort` reorders the array in place.
  kernelInfo.kernels.sort((a, b) => b.kernelTimeMs - a.kernelTimeMs);
  kernelInfo.aggregatedKernels = aggregateKernelTime(kernelInfo.kernels);
  return kernelInfo;
}
| 511 | |
| 512 | /** |
| 513 | * Aggregate kernels by name and sort the array in non-ascending order of time. |
no test coverage detected
searching dependent graphs…