/**
 * Wraps the model's inference call, bound to `input`, into a zero-argument
 * function.
 *
 * - tf.LayersModel: the returned function calls `model.predict(input)`.
 * - tf.GraphModel: the returned function calls `model.execute(input)` when a
 *   synchronous trial run succeeds, otherwise `model.executeAsync(input)`.
 *
 * Note: for a tf.GraphModel this function runs inference once (inside
 * `tf.tidy`) as a probe, so calling it has the cost of one inference.
 *
 * @param model An instance of tf.GraphModel or tf.LayersModel for finding and
 *     wrapping the predict function.
 * @param input The input tensor container for model inference.
 * @returns A zero-argument function that runs inference on `input`. When the
 *     tf.GraphModel async fallback is selected, the function returns a
 *     Promise.
 * @throws Error if `model` is neither a tf.GraphModel nor a tf.LayersModel.
 */
function getPredictFnForModel(model, input) {
  if (model instanceof tf.GraphModel) {
    // There's no straightforward way to analyze whether a graph has dynamic
    // ops, so try a synchronous `execute` first and fall back to
    // `executeAsync` if it throws. The trial run is wrapped in `tf.tidy` so
    // tensors it allocates are released.
    try {
      tf.tidy(() => {
        model.execute(input);
      });
    } catch (e) {
      // Any error from the trial run is treated as "needs async execution".
      // Errors unrelated to dynamic ops are not reported here; they can only
      // surface when the returned function is invoked.
      // `async` guarantees a Promise even if executeAsync throws synchronously.
      return async () => model.executeAsync(input);
    }
    return () => model.execute(input);
  }
  if (model instanceof tf.LayersModel) {
    return () => model.predict(input);
  }
  throw new Error(
      'Predict function was not found. Please provide a tf.GraphModel or ' +
      'tf.LayersModel');
}
| 175 | |
| 176 | /** |
| 177 | * Executes the predict function for `model` (`model.predict` for tf.LayersModel |
no test coverage detected
searching dependent graphs…