(
model, jenaWeatherData, normalize, includeDateTime, lookBack, step, delay,
batchSize, epochs, customCallback)
| 216 | * `onEpochEnd` fields. |
| 217 | */ |
/**
 * Trains `model` on the Jena weather data via `model.fitDataset()`.
 *
 * Two `tf.data.generator` datasets are built from
 * `jenaWeatherData.getNextBatchFunction()`:
 *   - training: shuffle flag `true`, rows bounded by TRAIN_MIN_ROW and
 *     TRAIN_MAX_ROW, with `prefetch(8)` applied;
 *   - validation: shuffle flag `false`, rows bounded by VAL_MIN_ROW and
 *     VAL_MAX_ROW.
 *
 * @param {Object} model Model to train; must provide `fitDataset()`
 *     (e.g. a compiled `tf.LayersModel`).
 * @param {Object} jenaWeatherData Data source exposing
 *     `getNextBatchFunction()`.
 * @param {boolean} normalize Forwarded to `getNextBatchFunction()`.
 * @param {boolean} includeDateTime Forwarded to `getNextBatchFunction()`.
 * @param {number} lookBack Forwarded to `getNextBatchFunction()`.
 * @param {number} step Forwarded to `getNextBatchFunction()`.
 * @param {number} delay Forwarded to `getNextBatchFunction()`.
 * @param {number} batchSize Forwarded to `getNextBatchFunction()`.
 * @param {number} epochs Number of epochs passed to `fitDataset()`.
 * @param {Object} customCallback Passed as `callbacks` to `fitDataset()`
 *     (e.g. an object with `onBatchEnd` / `onEpochEnd` fields).
 * @param {number} [batchesPerEpoch=500] Number of training batches drawn per
 *     epoch. Defaults to 500, the value this function previously hard-coded.
 * @returns {Promise<Object>} Whatever `model.fitDataset()` resolves to
 *     (a `tf.History` for TF.js layers models).
 */
export async function trainModel(
    model, jenaWeatherData, normalize, includeDateTime, lookBack, step, delay,
    batchSize, epochs, customCallback, batchesPerEpoch = 500) {
  const trainShuffle = true;
  const trainDataset =
      tf.data
          .generator(
              () => jenaWeatherData.getNextBatchFunction(
                  trainShuffle, lookBack, delay, batchSize, step, TRAIN_MIN_ROW,
                  TRAIN_MAX_ROW, normalize, includeDateTime))
          .prefetch(8);
  const evalShuffle = false;
  const valDataset = tf.data.generator(
      () => jenaWeatherData.getNextBatchFunction(
          evalShuffle, lookBack, delay, batchSize, step, VAL_MIN_ROW,
          VAL_MAX_ROW, normalize, includeDateTime));

  // `batchesPerEpoch` is what bounds each training epoch; the shuffled
  // training iterator presumably never reports `done` (TODO confirm in
  // getNextBatchFunction).
  // No `validationBatches` is given, so TF.js reads the validation dataset
  // until it ends. This relies on the unshuffled iterator eventually
  // reporting `done` (not visible here; confirm).
  return model.fitDataset(trainDataset, {
    batchesPerEpoch,
    epochs,
    callbacks: customCallback,
    validationData: valDataset,
  });
}
no test coverage detected