Repository: github.com/tensorflow/tfjs-examples

Function showPredictions

mnist/index.js, lines 225–253

Shows predictions on a number of test examples.

@param {tf.Model} model The model to be used for making the predictions.

showPredictions(model)


```javascript
// tf, data, and ui are module-level bindings defined elsewhere in
// mnist/index.js (not shown in this excerpt).

/**
 * Show predictions on a number of test examples.
 *
 * @param {tf.Model} model The model to be used for making the predictions.
 */
async function showPredictions(model) {
  const testExamples = 100;
  // These tensors are created outside tf.tidy(), so the tidy() call below
  // does not free them.
  const examples = data.getTestData(testExamples);

  // Tensors created inside a tf.tidy() callback (other than any tensor it
  // returns) are freed from GPU memory when the callback finishes, without
  // calling dispose() on each one. The callback must be synchronous.
  tf.tidy(() => {
    const output = model.predict(examples.xs);

    // tf.argMax() returns the indices of the maximum values in a tensor along
    // a given axis. Categorical classification tasks like this one often
    // represent classes as one-hot vectors: 1D vectors with one element per
    // output class, all 0 except for a single 1 (e.g. [0, 0, 0, 1, 0]).
    // Each row of the output from model.predict() is a probability
    // distribution over the classes, so argMax along axis 1 gives the index
    // of the most probable class. This is our prediction.
    // (e.g. argMax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3)
    // dataSync() synchronously downloads the tensor values from the GPU so
    // that ordinary CPU-side JavaScript can use them
    // (for a non-blocking version, use data()).
    const axis = 1;
    const labels = Array.from(examples.labels.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());

    ui.showTestResults(examples, predictions, labels);
  });
}
```
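To make the tf.tidy() comments concrete, here is a toy model of scope-based cleanup in plain JavaScript. It is a sketch under stated assumptions, not how TensorFlow.js implements tidy(): FakeTensor, track(), and tidy() below are hypothetical stand-ins. It illustrates three behaviours the comments rely on: tensors created inside the callback are disposed when it returns, a returned tensor survives, and a tensor created before the call (like `examples` in showPredictions) is left alone.

```javascript
// Hypothetical stand-in for a tensor that owns GPU memory.
class FakeTensor {
  constructor(values) {
    this.values = values;
    this.disposed = false;
  }
  dispose() {
    this.disposed = true;
  }
}

// Stack of open scopes; each scope lists the tensors created inside it.
const scopes = [];

// Register a new tensor with the innermost open scope, if there is one.
function track(tensor) {
  if (scopes.length > 0) scopes[scopes.length - 1].push(tensor);
  return tensor;
}

// Toy tidy(): run fn, then dispose everything it created except its result.
function tidy(fn) {
  scopes.push([]);
  let result;
  try {
    result = fn();
    if (result instanceof Promise) {
      // Code after an `await` would run after this scope closes, so this
      // toy rejects async callbacks outright.
      throw new Error('tidy() callback must be synchronous');
    }
  } finally {
    const created = scopes.pop();
    for (const t of created) {
      if (t !== result) t.dispose();
    }
  }
  // Hand a returned tensor to the enclosing scope, if any.
  if (result instanceof FakeTensor && scopes.length > 0) {
    scopes[scopes.length - 1].push(result);
  }
  return result;
}

const outside = track(new FakeTensor([1, 2, 3])); // made before tidy(), untracked
let temp;
const kept = tidy(() => {
  temp = track(new FakeTensor([0.1, 0.9])); // intermediate, freed on return
  return track(new FakeTensor([1]));        // returned, so kept
});
console.log(temp.disposed, kept.disposed, outside.disposed); // true false false
```

Because `outside` was never inside a scope, nothing frees it; in the real function that is why the tensors returned by data.getTestData() are not released by the tf.tidy() call.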
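The label and prediction decoding in showPredictions can be mirrored on plain arrays. The sketch below is illustrative only: oneHot() and argMaxRows() are hypothetical helpers, not TensorFlow.js APIs, and the second probability row is made up for the example. It encodes true classes as one-hot rows, then recovers class indices from both the labels and a batch of probability rows by taking the argmax of each row (axis 1).

```javascript
// Build a one-hot vector: all zeros except a 1 at classIndex.
function oneHot(classIndex, numClasses) {
  const v = new Array(numClasses).fill(0);
  v[classIndex] = 1;
  return v;
}

// argMax along axis 1: for each row, the index of its largest value.
// Ties resolve to the first (lowest) index in this sketch.
function argMaxRows(rows) {
  return rows.map((row) => {
    let best = 0;
    for (let i = 1; i < row.length; i++) {
      if (row[i] > row[best]) best = i;
    }
    return best;
  });
}

// Two examples with true classes 3 and 0, encoded as one-hot rows.
const oneHotLabels = [oneHot(3, 5), oneHot(0, 5)];
// Two rows of model output probabilities (each row sums to 1).
const probabilities = [
  [0.07, 0.1, 0.03, 0.75, 0.05], // the example from the source comment
  [0.2, 0.5, 0.1, 0.1, 0.1],
];

const labels = argMaxRows(oneHotLabels);       // [3, 0]
const predictions = argMaxRows(probabilities); // [3, 1]
console.log(labels, predictions);
```

In the real function the same decoding runs on tensors via `argMax(1)`, and `dataSync()` copies the resulting indices into a typed array, which `Array.from` turns into a plain array like the ones above.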

Callers (1)

index.js (file) · 0.85

Calls (2)

getTestData (method) · 0.45
predict (method) · 0.45

Tested by

No test coverage detected.