showPredictions(model) — Show predictions on a number of test examples.
@param {tf.Model} model The model to be used for making the predictions.
/**
 * Show predictions on a number of test examples.
 *
 * Fetches a batch of test examples and runs the model on them. Both the model
 * output and the ground-truth labels are reduced to class indices with
 * argMax along axis 1, and the results are passed to `ui.showTestResults`.
 *
 * @param {tf.Model} model The model to be used for making the predictions.
 * @param {number} [testExamples=100] Number of test examples to request from
 *     `data.getTestData`. Defaults to 100.
 * @returns {Promise<void>} Resolves once the results have been handed to the
 *     UI. The function is declared `async` so existing callers can keep
 *     awaiting it, but it performs no asynchronous work itself.
 */
async function showPredictions(model, testExamples = 100) {
  const examples = data.getTestData(testExamples);

  // NOTE(review): `examples.xs` / `examples.labels` are created outside the
  // tf.tidy() below, so tidy does not free them. If getTestData returns fresh
  // tensors on every call, they leak here. Confirm who owns them (and that
  // ui.showTestResults, which also receives `examples`, does not use them
  // asynchronously) before adding dispose() calls.

  // Tensors created inside the tf.tidy() callback are freed from GPU memory
  // after it returns, without explicit dispose() calls. The callback runs
  // synchronously.
  tf.tidy(() => {
    const output = model.predict(examples.xs);

    // argMax returns the index of the largest value along the given axis.
    // Categorical labels are one-hot vectors (all 0 except a single 1, e.g.
    // [0, 0, 0, 1, 0]), so argMax recovers the true class index. The model
    // output is a probability distribution, so argMax picks the most likely
    // class as the prediction, e.g. argMax([0.07, 0.1, 0.03, 0.75, 0.05]) === 3.
    // dataSync() downloads the values synchronously for use in plain
    // JavaScript; data() is the non-blocking alternative.
    const axis = 1;
    const labels = Array.from(examples.labels.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());

    ui.showTestResults(examples, predictions, labels);
  });
}
| 254 | |
| 255 | function createModel() { |
| 256 | let model; |
no test coverage detected