MCPcopy
hub / github.com/tensorflow/tfjs-examples / main

Function main

quantization/eval_mobilenetv2.js:73–141  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

71}
72
73async function main() {
74 const args = parseArgs();
75 if (args.gpu) {
76 tf = require('@tensorflow/tfjs-node-gpu');
77 } else {
78 tf = require('@tensorflow/tfjs-node');
79 }
80
81 console.log(`Loading model from ${args.modelSavePath}...`);
82 const model = await tf.loadLayersModel(`file://${args.modelSavePath}`);
83
84 const imageH = model.inputs[0].shape[2];
85 const imageW = model.inputs[0].shape[2];
86
87 // Load the images into a tensor.
88 const dirContent = fs.readdirSync(args.imageDir);
89 dirContent.sort();
90 const numImages = dirContent.length;
91 console.log(`Reading ${numImages} images...`);
92 const progressBar = new ProgressBar('[:bar]', {
93 total: numImages,
94 width: 80,
95 head: '>'
96 });
97 const imageTensors = [];
98 const truthLabels = [];
99 for (const fileName of dirContent) {
100 const truthLabel = fileName.split('.')[0].split('_')[2];
101 truthLabels.push(truthLabel);
102 const imageFilePath = path.join(args.imageDir, fileName);
103 const imageTensor =
104 await readImageTensorFromFile(imageFilePath, imageH, imageW);
105 imageTensors.push(imageTensor);
106 progressBar.tick();
107 }
108
109 const stackedImageTensor = tf.concat(imageTensors, 0);
110 console.log('Calling model.predict()...');
111 const t0 = new Date().getTime();
112 const {top1Indices, top5Indices} = tf.tidy(() => {
113 const probs = model.predict(stackedImageTensor, {batchSize: 64});
114 return {
115 top1Indices: probs.argMax(-1).arraySync(),
116 top5Indices: probs.topk(5).indices.arraySync()
117 };
118 });
119 console.log(`model.predict() took ${(new Date().getTime() - t0).toFixed(2)} ms`);
120
121 let numCorrectTop1 = 0;
122 let numCorrectTop5 = 0;
123 top1Indices.forEach((top1Index, i) => {
124 const truthLabel = truthLabels[i];
125 if (IMAGENET_CLASSES[top1Index].indexOf(truthLabel) !== -1) {
126 numCorrectTop1++;
127 }
128 for (let k = 0; k < 5; ++k) {
129 if (IMAGENET_CLASSES[top5Indices[i][k]].indexOf(truthLabel) !== -1) {
130 numCorrectTop5++;

Callers 1

Calls 4

getTime (method) · 0.80
parseArgs (function) · 0.70
readImageTensorFromFile (function) · 0.70
predict (method) · 0.45

Tested by

no test coverage detected