Function trainModelRunner

tfjs-react-native/integration_rn59/components/ml.ts:128–145 (github.com/tensorflow/tfjs)
trainModelRunner(): Promise<() => Promise<string>>

```typescript
import * as tf from "@tensorflow/tfjs"; // assumed: the excerpt omits the file's imports

/**
 * A runner that trains a model.
 */
export async function trainModelRunner() {
  // Setup runs once, when trainModelRunner() is awaited.
  // Define a model for linear regression. Neither dense layer sets an
  // activation, so both use the default linear activation.
  const model = tf.sequential();
  model.add(tf.layers.dense({ units: 5, inputShape: [1] }));
  model.add(tf.layers.dense({ units: 1 }));
  model.compile({ loss: "meanSquaredError", optimizer: "sgd" });

  // Generate some synthetic data for training (the points lie on y = 2x - 1).
  const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
  const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

  // Each call of the returned function trains the same model for 20 more epochs.
  return async () => {
    // Train the model using the data.
    await model.fit(xs, ys, { epochs: 20 });

    return "done";
  };
}
```
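trainModelRunner works in two phases: awaiting it builds the model and the xs/ys tensors once, and every call of the returned function runs model.fit again on that same, already partly trained model. The sketch below reproduces only that shape in plain TypeScript so it runs without tfjs; makeRunner, runRepeatedly and the log array are hypothetical names for illustration, not part of the tfjs integration app.

```typescript
// A runner performs one unit of work and reports a status string.
type Runner = () => Promise<string>;

// Hypothetical stand-in for trainModelRunner: setup runs once per factory call,
// and the returned closure can be invoked many times, reusing that state.
function makeRunner(log: string[]): () => Promise<Runner> {
  return async () => {
    log.push("setup"); // analogous to building the model and the xs/ys tensors
    let epochsRun = 0; // captured state, like `model` in the original
    return async () => {
      epochsRun += 20; // analogous to model.fit(xs, ys, { epochs: 20 })
      log.push(`run:${epochsRun}`);
      return "done";
    };
  };
}

// Hypothetical harness: build the runner once, then call it `times` times.
async function runRepeatedly(
  factory: () => Promise<Runner>,
  times: number,
): Promise<string[]> {
  const run = await factory(); // setup phase
  const results: string[] = [];
  for (let i = 0; i < times; i++) {
    results.push(await run()); // run phase
  }
  return results;
}
```

Building the runner once and calling it three times records a single "setup" followed by "run:20", "run:40" and "run:60". Because the original keeps the model outside the closure, repeated calls continue training from the current weights rather than starting fresh; a caller that wants identical work per call would have to rebuild the model inside the returned function.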

Callers

nothing calls this directly

Calls (3)

addMethod · 0.65
compileMethod · 0.45
fitMethod · 0.45
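The three calls listed above (add twice, compile, fit) set up a very small regression problem. Neither dense layer sets an activation, and tfjs dense layers default to a linear one, so the stacked layers reduce to a single affine map of x; the four training points also lie exactly on y = 2x − 1. In the notation below, introduced here for illustration rather than taken from the source, W1 (5×1) and b1 (length 5) are the first layer's kernel and bias, and W2 (1×5) and b2 (scalar) are the second layer's:

```latex
\begin{aligned}
\hat{y}(x) &= W_2 (W_1 x + b_1) + b_2
  = \underbrace{W_2 W_1}_{w}\, x + \underbrace{W_2 b_1 + b_2}_{b} \\
\mathcal{L} &= \frac{1}{4} \sum_{i=1}^{4} \bigl(\hat{y}(x_i) - y_i\bigr)^2,
  \qquad (x_i, y_i) \in \{(1, 1), (2, 3), (3, 5), (4, 7)\} \\
y_i &= 2 x_i - 1 \;\Rightarrow\; \min \mathcal{L} = 0 \text{ at } w = 2,\ b = -1
\end{aligned}
```

Each call of the returned runner takes 20 more SGD epochs toward this minimum; how close it gets depends on the random initialization and the optimizer's learning rate, so the loss after a call should not be expected to be exactly zero.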

Tested by

no test coverage detected
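With no tests, nothing checks that the runner actually trains; asserting on the returned "done" says little. A more informative property is that the loss falls across epochs. The sketch below demonstrates that property on a plain-TypeScript stand-in rather than on tfjs: a single affine unit y_hat = w·x + b (the form the two-layer model reduces to) trained by full-batch gradient descent on mean squared error over the same four points. The helper names mse and train are hypothetical, the learning rate of 0.01 is an assumption meant to mirror the "sgd" optimizer string's default, and starting from w = b = 0 is a simplification; tfjs initializes weights randomly and updates both layers, so its loss curve will differ.

```typescript
// Training data copied from trainModelRunner: the points lie on y = 2x - 1.
const xsData = [1, 2, 3, 4];
const ysData = [1, 3, 5, 7];

// Mean squared error of the affine model y_hat = w * x + b over all four points.
function mse(w: number, b: number): number {
  let sum = 0;
  for (let i = 0; i < xsData.length; i++) {
    const err = w * xsData[i] + b - ysData[i];
    sum += err * err;
  }
  return sum / xsData.length;
}

// Full-batch gradient descent. With 4 samples and fit's default batchSize of 32,
// each epoch in the original is likewise one update computed over all points.
function train(epochs: number, learningRate: number): number[] {
  let w = 0;
  let b = 0;
  const n = xsData.length;
  const losses = [mse(w, b)]; // loss before any update
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const err = w * xsData[i] + b - ysData[i];
      gradW += (2 / n) * err * xsData[i]; // dL/dw
      gradB += (2 / n) * err; // dL/db
    }
    w -= learningRate * gradW;
    b -= learningRate * gradB;
    losses.push(mse(w, b)); // loss after this epoch
  }
  return losses;
}

const losses = train(20, 0.01);
console.log(losses[0].toFixed(3), losses[losses.length - 1].toFixed(3));
```

A test of the real runner could check the same property by reading the loss from the History object that model.fit resolves to, instead of only checking the returned string.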
