| 126 | * A runner that trains a model. |
| 127 | */ |
export async function trainModelRunner() {
  // Linear regression model: two stacked dense layers, no activation
  // specified on either, taking a single input feature.
  const net = tf.sequential();
  const layerConfigs = [{ units: 5, inputShape: [1] }, { units: 1 }];
  for (const config of layerConfigs) {
    net.add(tf.layers.dense(config));
  }
  net.compile({ optimizer: "sgd", loss: "meanSquaredError" });

  // Four synthetic samples of the line y = 2x - 1, one feature per row.
  const features = tf.tensor2d([1, 2, 3, 4], [4, 1]);
  const labels = tf.tensor2d([1, 3, 5, 7], [4, 1]);

  // Every call of the returned runner fits the same model on the same
  // tensors for 20 epochs, then resolves to "done".
  // NOTE(review): net/features/labels are never disposed; presumably
  // intentional because the runner may be invoked repeatedly — confirm.
  const run = async () => {
    await net.fit(features, labels, { epochs: 20 });
    return "done";
  };
  return run;
}
| 146 | |
| 147 | /** |
| 148 | * A runner that saves and loads a model to/from asyncStorage |