* Sets up and trains the classifier.
()
| 62 | * Sets up and trains the classifier. |
| 63 | */ |
| 64 | async function train() { |
| 65 | if (controllerDataset.xs == null) { |
| 66 | throw new Error('Add some examples before training!'); |
| 67 | } |
| 68 | |
| 69 | // Creates a 2-layer fully connected model. By creating a separate model, |
| 70 | // rather than adding layers to the mobilenet model, we "freeze" the weights |
| 71 | // of the mobilenet model, and only train weights from the new model. |
| 72 | model = tf.sequential({ |
| 73 | layers: [ |
| 74 | // Flattens the input to a vector so we can use it in a dense layer. While |
| 75 | // technically a layer, this only performs a reshape (and has no training |
| 76 | // parameters). |
| 77 | tf.layers.flatten( |
| 78 | {inputShape: truncatedMobileNet.outputs[0].shape.slice(1)}), |
| 79 | // Layer 1. |
| 80 | tf.layers.dense({ |
| 81 | units: ui.getDenseUnits(), |
| 82 | activation: 'relu', |
| 83 | kernelInitializer: 'varianceScaling', |
| 84 | useBias: true |
| 85 | }), |
| 86 | // Layer 2. The number of units of the last layer should correspond |
| 87 | // to the number of classes we want to predict. |
| 88 | tf.layers.dense({ |
| 89 | units: NUM_CLASSES, |
| 90 | kernelInitializer: 'varianceScaling', |
| 91 | useBias: false, |
| 92 | activation: 'softmax' |
| 93 | }) |
| 94 | ] |
| 95 | }); |
| 96 | |
| 97 | // Creates the optimizer that drives training of the model. |
| 98 | const optimizer = tf.train.adam(ui.getLearningRate()); |
| 99 | // We use categoricalCrossentropy, the loss function for categorical |
| 100 | // classification. It measures the error between our predicted |
| 101 | // probability distribution over classes (probability that an input is of each |
| 102 | // class) and the label (100% probability in the true class). |
| 103 | model.compile({optimizer: optimizer, loss: 'categoricalCrossentropy'}); |
| 104 | |
| 105 | // We parameterize batch size as a fraction of the entire dataset because the |
| 106 | // number of examples that are collected depends on how many examples the user |
| 107 | // collects. This allows us to have a flexible batch size. |
| 108 | const batchSize = |
| 109 | Math.floor(controllerDataset.xs.shape[0] * ui.getBatchSizeFraction()); |
| 110 | if (!(batchSize > 0)) { |
| 111 | throw new Error( |
| 112 | `Batch size is 0 or NaN. Please choose a non-zero fraction.`); |
| 113 | } |
| 114 | |
| 115 | // Train the model! Model.fit() will shuffle xs & ys so we don't have to. |
| 116 | model.fit(controllerDataset.xs, controllerDataset.ys, { |
| 117 | batchSize, |
| 118 | epochs: ui.getEpochs(), |
| 119 | callbacks: { |
| 120 | onBatchEnd: async (batch, logs) => { |
| 121 | ui.trainStatus('Loss: ' + logs.loss.toFixed(5)); |