MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / train

Function train

webcam-transfer-learning/index.js:64–125  ·  view source on GitHub ↗

Sets up and trains the classifier.

()

Source from the content-addressed store, hash-verified

62 * Sets up and trains the classifier.
63 */
64async function train() {
65 if (controllerDataset.xs == null) {
66 throw new Error('Add some examples before training!');
67 }
68
69 // Creates a 2-layer fully connected model. By creating a separate model,
70 // rather than adding layers to the mobilenet model, we "freeze" the weights
71 // of the mobilenet model, and only train weights from the new model.
72 model = tf.sequential({
73 layers: [
74 // Flattens the input to a vector so we can use it in a dense layer. While
75 // technically a layer, this only performs a reshape (and has no training
76 // parameters).
77 tf.layers.flatten(
78 {inputShape: truncatedMobileNet.outputs[0].shape.slice(1)}),
79 // Layer 1.
80 tf.layers.dense({
81 units: ui.getDenseUnits(),
82 activation: 'relu',
83 kernelInitializer: 'varianceScaling',
84 useBias: true
85 }),
86 // Layer 2. The number of units of the last layer should correspond
87 // to the number of classes we want to predict.
88 tf.layers.dense({
89 units: NUM_CLASSES,
90 kernelInitializer: 'varianceScaling',
91 useBias: false,
92 activation: 'softmax'
93 })
94 ]
95 });
96
97 // Creates the optimizers which drives training of the model.
98 const optimizer = tf.train.adam(ui.getLearningRate());
99 // We use categoricalCrossentropy which is the loss function we use for
100 // categorical classification which measures the error between our predicted
101 // probability distribution over classes (probability that an input is of each
102 // class), versus the label (100% probability in the true class)>
103 model.compile({optimizer: optimizer, loss: 'categoricalCrossentropy'});
104
105 // We parameterize batch size as a fraction of the entire dataset because the
106 // number of examples that are collected depends on how many examples the user
107 // collects. This allows us to have a flexible batch size.
108 const batchSize =
109 Math.floor(controllerDataset.xs.shape[0] * ui.getBatchSizeFraction());
110 if (!(batchSize > 0)) {
111 throw new Error(
112 `Batch size is 0 or NaN. Please choose a non-zero fraction.`);
113 }
114
115 // Train the model! Model.fit() will shuffle xs & ys so we don't have to.
116 model.fit(controllerDataset.xs, controllerDataset.ys, {
117 batchSize,
118 epochs: ui.getEpochs(),
119 callbacks: {
120 onBatchEnd: async (batch, logs) => {
121 ui.trainStatus('Loss: ' + logs.loss.toFixed(5));

Callers 1

index.js · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected