MCPcopy
hub / github.com/tensorflow/tfjs / profileInference

Function profileInference

e2e/benchmarks/benchmark_util.js:465–510  ·  view source on GitHub ↗

Executes `predict()` and returns a promise that resolves with information about the memory usage: `newBytes` (the number of new bytes allocated), `newTensors` (the number of new tensors created), `peakBytes` (the peak number of bytes allocated), and `kernels` (an array of kernel in…

(predict, isTflite = false, numProfiles = 1)

Source from the content-addressed store, hash-verified

463 * @param numProfiles The number of rounds for `predict` to execute and profile.
464 */
async function profileInference(predict, isTflite = false, numProfiles = 1) {
  // Runs `predict` for `numProfiles` rounds, collecting one profile record per
  // round, and resolves with the last round's record after:
  //   - replacing every kernel's `kernelTimeMs` with its mean over all rounds,
  //   - sorting `kernels` by descending `kernelTimeMs`,
  //   - attaching `aggregatedKernels` (the result of `aggregateKernelTime`).
  if (typeof predict !== 'function') {
    throw new Error(
        'The first parameter should be a function, while ' +
        `a(n) ${typeof predict} is found.`);
  }
  // At least one round is required: the averaging step below uses the first
  // round's kernel list as its template. The condition mirrors the loop test
  // (`0 < numProfiles`), so it rejects exactly the inputs for which no round
  // would run (0, negatives, NaN, null).
  if (!(numProfiles > 0)) {
    throw new Error(
        'The third parameter should be a positive number, while ' +
        `${numProfiles} is found.`);
  }

  const kernelInfos = [];
  for (let i = 0; i < numProfiles; i++) {
    if (isTflite) {
      await predict();
      const profileItems = await tfliteModel.getProfilingResults();
      // Build a fresh record for every round. Reusing one object across
      // rounds would make every entry of `kernelInfos` alias the last round's
      // data, turning the average below into "last round only".
      kernelInfos.push({
        kernels: profileItems.map(item => ({
          name: item.nodeType,
          kernelTimeMs: item.nodeExecMs,
          // TODO: Shapes are not supported yet.
          inputShapes: [],
          outputShapes: [],
        })),
      });
    } else {
      const roundInfo = await tf.profile(async () => {
        const res = await predict();
        await downloadValuesFromTensorContainer(res);
        tf.dispose(res);
      });
      kernelInfos.push(roundInfo);
    }
  }

  // The last round's record is the one returned; its per-kernel times are
  // overwritten in place with the cross-round mean.
  const kernelInfo = kernelInfos[kernelInfos.length - 1];
  // Kernels are matched across rounds by position in the `kernels` array.
  for (let i = 0; i < kernelInfos[0].kernels.length; i++) {
    let totalTimeMs = 0;
    for (let j = 0; j < kernelInfos.length; j++) {
      totalTimeMs += kernelInfos[j].kernels[i].kernelTimeMs;
    }
    kernelInfo.kernels[i].kernelTimeMs = totalTimeMs / kernelInfos.length;
  }
  // `sort` reorders the array in place; slowest kernels come first.
  kernelInfo.kernels.sort((a, b) => b.kernelTimeMs - a.kernelTimeMs);
  kernelInfo.aggregatedKernels = aggregateKernelTime(kernelInfo.kernels);
  return kernelInfo;
}
511
512/**
513 * Aggregate kernels by name and sort the array in non-ascending order of time.

Callers 4

profileModelInferenceFunction · 0.85
benchmarkModelFunction · 0.85
benchmarkCodeSnippetFunction · 0.85

Calls 7

aggregateKernelTimeFunction · 0.85
profileMethod · 0.80
predictFunction · 0.70
getProfilingResultsMethod · 0.65
pushMethod · 0.45
disposeMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…