* Perform classification on a batch of image tensors. * * @param {tf.Tensor} images Batch image tensor of shape * `[numExamples, height, width, channels]`. The values of `height`, * `width` and `channels` must match the underlying MobileNetV2 model * (default: 224, 224, 3). *
(images, topK = 5)
| 51 | * of `images`. |
| 52 | */ |
| 53 | async classify(images, topK = 5) { |
| 54 | await this.ensureModelLoaded(); |
| 55 | return tf.tidy(() => { |
| 56 | const probs = this.model.predict(images); |
| 57 | const sorted = true; |
| 58 | const {values, indices} = tf.topk(probs, topK, sorted); |
| 59 | |
| 60 | const classProbs = values.arraySync(); |
| 61 | const classIndices = indices.arraySync(); |
| 62 | |
| 63 | const results = []; |
| 64 | classIndices.forEach((indices, i) => { |
| 65 | const classesAndProbs = []; |
| 66 | indices.forEach((index, j) => { |
| 67 | classesAndProbs.push({ |
| 68 | className: IMAGENET_CLASSES[index], |
| 69 | prob: classProbs[i][j] |
| 70 | }); |
| 71 | }); |
| 72 | results.push(classesAndProbs); |
| 73 | }) |
| 74 | |
| 75 | return results; |
| 76 | }); |
| 77 | } |
| 78 | |
| 79 | /** |
| 80 | * If the underlying model is not loaded, load it. |
no test coverage detected