hub / github.com/tensorflow/tfjs-examples / train

Function train

mnist/index.js:124–218

Compile and train the given model.

@param {tf.Model} model The model to train.
@param {onIterationCallback} onIteration A callback to execute every 10 batches & epoch end.

(model, onIteration)


 * batches & epoch end.
 */
async function train(model, onIteration) {
  ui.logStatus('Training model...');

  // Now that we've defined our model, we will define our optimizer. The
  // optimizer will be used to optimize our model's weight values during
  // training so that we can decrease our training loss and increase our
  // classification accuracy.

  // We are using rmsprop as our optimizer.
  // An optimizer is an iterative method for minimizing a loss function.
  // It tries to find the minimum of our loss function with respect to the
  // model's weight parameters.
  const optimizer = 'rmsprop';

  // We compile our model by specifying an optimizer, a loss function, and a
  // list of metrics that we will use for model evaluation. Here we're using a
  // categorical crossentropy loss, the standard choice for a multi-class
  // classification problem like MNIST digits.
  // The categorical crossentropy loss is differentiable and hence makes
  // model training possible. But it is not amenable to easy interpretation
  // by a human. This is why we include a "metric", namely accuracy, which is
  // simply a measure of how many of the examples are classified correctly.
  // This metric is not differentiable and hence cannot be used as the loss
  // function of the model.
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  // Batch size is another important hyperparameter. It defines the number of
  // examples we group together, or batch, between updates to the model's
  // weights during training. A value that is too low will update weights using
  // too few examples and will not generalize well. Larger batch sizes require
  // more memory resources and aren't guaranteed to perform better.
  const batchSize = 320;

  // Leave out the last 15% of the training data for validation, to monitor
  // overfitting during training.
  const validationSplit = 0.15;

  // Get number of training epochs from the UI.
  const trainEpochs = ui.getTrainEpochs();

  // We'll keep a buffer of loss and accuracy values over time.
  let trainBatchCount = 0;

  const trainData = data.getTrainData();
  const testData = data.getTestData();

  const totalNumBatches =
      Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) *
      trainEpochs;

  // During the long-running fit() call for model training, we include
  // callbacks, so that we can plot the loss and accuracy values in the page
  // as the training progresses.
  let valAcc;
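
The comments before `model.compile` contrast the loss (differentiable, hard to read) with the accuracy metric (readable, not differentiable). The sketch below illustrates that contrast in plain JavaScript, without TensorFlow.js. The helper names (`categoricalCrossentropy`, `argMax`, `accuracy`, `meanLoss`) and the sample predictions are hypothetical and chosen only for illustration; this is not the library's implementation.

```javascript
// Illustrative helpers only; not the TensorFlow.js implementation.

// Categorical cross-entropy for one example: -sum(yTrue[i] * log(yPred[i])).
// eps guards against log(0).
function categoricalCrossentropy(yTrue, yPred, eps = 1e-7) {
  return -yTrue.reduce(
      (sum, t, i) => sum + t * Math.log(Math.max(yPred[i], eps)), 0);
}

// Index of the largest value, i.e. the predicted (or true) class.
function argMax(values) {
  return values.indexOf(Math.max(...values));
}

// Fraction of examples whose predicted class matches the label.
function accuracy(labels, preds) {
  const correct =
      labels.filter((y, i) => argMax(y) === argMax(preds[i])).length;
  return correct / labels.length;
}

// Mean cross-entropy over a set of examples.
function meanLoss(labels, preds) {
  const total = labels.reduce(
      (sum, y, i) => sum + categoricalCrossentropy(y, preds[i]), 0);
  return total / labels.length;
}

// Two one-hot labels and two sets of made-up predictions.
const labels = [[1, 0, 0], [0, 1, 0]];
const confident = [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]];
const hesitant = [[0.4, 0.3, 0.3], [0.3, 0.4, 0.3]];

console.log(accuracy(labels, confident), meanLoss(labels, confident));
console.log(accuracy(labels, hesitant), meanLoss(labels, hesitant));
```

Both prediction sets classify every example correctly, so accuracy cannot tell them apart, while the loss is lower for the more confident set. That smooth response to small changes in the predicted probabilities is what gives the optimizer a gradient to follow.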

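The `totalNumBatches` expression can be checked by hand. The sketch below repeats the same arithmetic in a standalone function; `countBatches` is a hypothetical name, and the 55,000 training examples and 3 epochs are assumed values for illustration, since the real numbers come from `data.getTrainData()` and `ui.getTrainEpochs()`.

```javascript
// Hypothetical helper mirroring the totalNumBatches arithmetic in train().
function countBatches(numExamples, validationSplit, batchSize, epochs) {
  // Only the non-validation share of the data is used for weight updates.
  const batchesPerEpoch =
      Math.ceil(numExamples * (1 - validationSplit) / batchSize);
  return {batchesPerEpoch, totalNumBatches: batchesPerEpoch * epochs};
}

// Assumed inputs: 55000 examples and 3 epochs are illustrative values.
const result = countBatches(55000, 0.15, 320, 3);
console.log(result.batchesPerEpoch);  // 147
console.log(result.totalNumBatches);  // 441
```

With these assumed inputs, 85% of 55,000 is 46,750 examples, which fills 146 batches of 320 with a smaller final batch left over, so `Math.ceil` rounds up to 147 batches per epoch.
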
Callers 1

index.js (File) · 0.70

Calls 2

getTrainData (Method) · 0.45
getTestData (Method) · 0.45

Tested by

no test coverage detected