()
| 71 | } |
| 72 | |
| 73 | async function main() { |
| 74 | const args = parseArgs(); |
| 75 | if (args.gpu) { |
| 76 | tf = require('@tensorflow/tfjs-node-gpu'); |
| 77 | } else { |
| 78 | tf = require('@tensorflow/tfjs-node'); |
| 79 | } |
| 80 | |
| 81 | console.log(`Loading model from ${args.modelSavePath}...`); |
| 82 | const model = await tf.loadLayersModel(`file://${args.modelSavePath}`); |
| 83 | |
| 84 | const imageH = model.inputs[0].shape[2]; |
| 85 | const imageW = model.inputs[0].shape[2]; |
| 86 | |
| 87 | // Load the images into a tensor. |
| 88 | const dirContent = fs.readdirSync(args.imageDir); |
| 89 | dirContent.sort(); |
| 90 | const numImages = dirContent.length; |
| 91 | console.log(`Reading ${numImages} images...`); |
| 92 | const progressBar = new ProgressBar('[:bar]', { |
| 93 | total: numImages, |
| 94 | width: 80, |
| 95 | head: '>' |
| 96 | }); |
| 97 | const imageTensors = []; |
| 98 | const truthLabels = []; |
| 99 | for (const fileName of dirContent) { |
| 100 | const truthLabel = fileName.split('.')[0].split('_')[2]; |
| 101 | truthLabels.push(truthLabel); |
| 102 | const imageFilePath = path.join(args.imageDir, fileName); |
| 103 | const imageTensor = |
| 104 | await readImageTensorFromFile(imageFilePath, imageH, imageW); |
| 105 | imageTensors.push(imageTensor); |
| 106 | progressBar.tick(); |
| 107 | } |
| 108 | |
| 109 | const stackedImageTensor = tf.concat(imageTensors, 0); |
| 110 | console.log('Calling model.predict()...'); |
| 111 | const t0 = new Date().getTime(); |
| 112 | const {top1Indices, top5Indices} = tf.tidy(() => { |
| 113 | const probs = model.predict(stackedImageTensor, {batchSize: 64}); |
| 114 | return { |
| 115 | top1Indices: probs.argMax(-1).arraySync(), |
| 116 | top5Indices: probs.topk(5).indices.arraySync() |
| 117 | }; |
| 118 | }); |
| 119 | console.log(`model.predict() took ${(new Date().getTime() - t0).toFixed(2)} ms`); |
| 120 | |
| 121 | let numCorrectTop1 = 0; |
| 122 | let numCorrectTop5 = 0; |
| 123 | top1Indices.forEach((top1Index, i) => { |
| 124 | const truthLabel = truthLabels[i]; |
| 125 | if (IMAGENET_CLASSES[top1Index].indexOf(truthLabel) !== -1) { |
| 126 | numCorrectTop1++; |
| 127 | } |
| 128 | for (let k = 0; k < 5; ++k) { |
| 129 | if (IMAGENET_CLASSES[top5Indices[i][k]].indexOf(truthLabel) !== -1) { |
| 130 | numCorrectTop5++; |
no test coverage detected