* Predicts intermediate Tensor representations. * * @param input The input RGB image of the base model. * A Tensor of shape: [`inputResolution`, `inputResolution`, 3]. * * @return A dictionary of base model's intermediate predictions. * The returned dictionary should contain the f
(input: tf.Tensor3D)
| 56 | * displacementBwd: A Tensor3D that represents the backward displacement. |
| 57 | */ |
predict(input: tf.Tensor3D): {
  heatmapScores: tf.Tensor3D,
  offsets: tf.Tensor3D,
  displacementFwd: tf.Tensor3D,
  displacementBwd: tf.Tensor3D
} {
  // Everything allocated inside tidy is disposed on exit except the tensors
  // reachable from the returned object.
  return tf.tidy(() => {
    // float32 cast -> model-specific preprocessing -> leading batch dim of 1.
    const batched =
        tf.expandDims(this.preprocessInput(tf.cast(input, 'float32')), 0);
    const rawOutputs = this.model.predict(batched) as tf.Tensor4D[];

    // Drop the batch dimension from every model output.
    const unbatched: tf.Tensor3D[] = [];
    for (const output of rawOutputs) {
      unbatched.push(tf.squeeze<tf.Tensor3D>(output, [0]));
    }

    // Output order differs between architectures; the subclass-provided
    // nameOutputResults maps positional outputs to named fields.
    const {heatmap, offsets, displacementFwd, displacementBwd} =
        this.nameOutputResults(unbatched);

    return {
      // The heatmap is turned into scores via a sigmoid; the other outputs
      // are passed through unchanged.
      heatmapScores: tf.sigmoid(heatmap),
      offsets,
      displacementFwd,
      displacementBwd
    };
  });
}
| 80 | |
| 81 | // Because MobileNet and ResNet predict() methods output a different order for |
| 82 | // these values, we have a method that needs to be implemented to order them. |
nothing calls this directly
no test coverage detected