* Compile and train the given model. * * @param {tf.Model} model The model to train. * @param {onIterationCallback} onIteration A callback to execute every 10 * batches & epoch end.
(model, onIteration)
| 122 | * batches & epoch end. |
| 123 | */ |
| 124 | async function train(model, onIteration) { |
| 125 | ui.logStatus('Training model...'); |
| 126 | |
| 127 | // Now that we've defined our model, we will define our optimizer. The |
| 128 | // optimizer will be used to optimize our model's weight values during |
| 129 | // training so that we can decrease our training loss and increase our |
| 130 | // classification accuracy. |
| 131 | |
| 132 | // We are using rmsprop as our optimizer. |
| 133 | // An optimizer is an iterative method for minimizing a loss function. |
| 134 | // It tries to find the minimum of our loss function with respect to the |
| 135 | // model's weight parameters. |
| 136 | const optimizer = 'rmsprop'; |
| 137 | |
| 138 | // We compile our model by specifying an optimizer, a loss function, and a |
| 139 | // list of metrics that we will use for model evaluation. Here we're using a |
| 140 | // categorical crossentropy loss, the standard choice for a multi-class |
| 141 | // classification problem like MNIST digits. |
| 142 | // The categorical crossentropy loss is differentiable and hence makes |
| 143 | // model training possible. But it is not amenable to easy interpretation |
| 144 | // by a human. This is why we include a "metric", namely accuracy, which is |
| 145 | // simply a measure of how many of the examples are classified correctly. |
| 146 | // This metric is not differentiable and hence cannot be used as the loss |
| 147 | // function of the model. |
| 148 | model.compile({ |
| 149 | optimizer, |
| 150 | loss: 'categoricalCrossentropy', |
| 151 | metrics: ['accuracy'], |
| 152 | }); |
| 153 | |
| 154 | // Batch size is another important hyperparameter. It defines the number of |
| 155 | // examples we group together, or batch, between updates to the model's |
| 156 | // weights during training. A value that is too low will update weights using |
| 157 | // too few examples and will not generalize well. Larger batch sizes require |
| 158 | // more memory resources and aren't guaranteed to perform better. |
| 159 | const batchSize = 320; |
| 160 | |
| 161 | // Leave out the last 15% of the training data for validation, to monitor |
| 162 | // overfitting during training. |
| 163 | const validationSplit = 0.15; |
| 164 | |
| 165 | // Get number of training epochs from the UI. |
| 166 | const trainEpochs = ui.getTrainEpochs(); |
| 167 | |
| 168 | // We'll keep a buffer of loss and accuracy values over time. |
| 169 | let trainBatchCount = 0; |
| 170 | |
| 171 | const trainData = data.getTrainData(); |
| 172 | const testData = data.getTestData(); |
| 173 | |
| 174 | const totalNumBatches = |
| 175 | Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) * |
| 176 | trainEpochs; |
| 177 | |
| 178 | // During the long-running fit() call for model training, we include |
| 179 | // callbacks, so that we can plot the loss and accuracy values in the page |
| 180 | // as the training progresses. |
| 181 | let valAcc; |
no test coverage detected