MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / train

Function train

iris/python/iris.py:31–76  ·  view source on GitHub ↗

Train a Keras model for Iris data classification and save result as JSON. Args: epochs: Number of epochs to train the Keras model for. artifacts_dir: Directory to save the model artifacts (model topology JSON, weights and weight manifest) in. sequential: Whether to use a Keras Sequential model, instead of the default functional model.

(epochs,
          artifacts_dir,
          sequential=False)

Source from the content-addressed store, hash-verified

29
30
def train(epochs,
          artifacts_dir,
          sequential=False):
    """Train a Keras model for Iris data classification and save result as JSON.

    Both model variants build the same two-layer network: a 10-unit sigmoid
    hidden layer named 'Dense1' followed by a softmax output layer named
    'Dense2'. The network is compiled with categorical cross-entropy and the
    Adam optimizer, then fit with a batch size of 8.

    Args:
      epochs: Number of epochs to train the Keras model for.
      artifacts_dir: Directory to save the model artifacts (model topology JSON,
        weights and weight manifest) in.
      sequential: Whether to use a Keras Sequential model, instead of the default
        functional model.

    Returns:
      Final classification accuracy on the training set.
    """
    data_x, data_y = iris_data.load()

    # Derive layer widths from the loaded data so both model variants agree
    # with it. The functional branch previously hard-coded an input width of 4.
    # data_y is one-hot encoded (see the argmax over axis 1 below), so its
    # second dimension is the number of classes.
    num_features = data_x.shape[1]
    num_classes = data_y.shape[1]

    if sequential:
        model = keras.models.Sequential()
        model.add(keras.layers.Dense(
            10, input_shape=[num_features], use_bias=True, activation='sigmoid',
            name='Dense1'))
        model.add(keras.layers.Dense(
            num_classes, use_bias=True, activation='softmax', name='Dense2'))
    else:
        iris_x = keras.layers.Input((num_features,))
        dense1 = keras.layers.Dense(
            10, use_bias=True, name='Dense1', activation='sigmoid')(iris_x)
        dense2 = keras.layers.Dense(
            num_classes, use_bias=True, name='Dense2',
            activation='softmax')(dense1)
        # pylint:disable=redefined-variable-type
        model = keras.models.Model(inputs=[iris_x], outputs=[dense2])
        # pylint:enable=redefined-variable-type
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    model.fit(data_x, data_y, batch_size=8, epochs=epochs)

    # Run prediction on the training set and compare the predicted class
    # indices against the one-hot labels' class indices.
    pred_ys = np.argmax(model.predict(data_x), axis=1)
    true_ys = np.argmax(data_y, axis=1)
    final_train_accuracy = np.mean((pred_ys == true_ys).astype(np.float32))
    print('Accuracy on the training set: %g' % final_train_accuracy)

    # Writes the model topology JSON, weights and weight manifest to
    # artifacts_dir.
    tfjs.converters.save_keras_model(model, artifacts_dir)

    return final_train_accuracy
77
78
79def main():

Callers 1

mainFunction · 0.70

Calls 2

loadMethod · 0.45
predictMethod · 0.45

Tested by

no test coverage detected