* Train the LSTM model. * * @param {number} numEpochs Number of epochs to train the model for. * @param {number} examplesPerEpoch Number of examples to use in each training * epoch. * @param {number} batchSize Batch size to use during training. * @param {number} validationSplit Va
(numEpochs, examplesPerEpoch, batchSize, validationSplit)
| 85 | * training epochs. |
| 86 | */ |
| 87 | async fitModel(numEpochs, examplesPerEpoch, batchSize, validationSplit) { |
| 88 | let batchCount = 0; |
| 89 | const batchesPerEpoch = examplesPerEpoch / batchSize; |
| 90 | const totalBatches = numEpochs * batchesPerEpoch; |
| 91 | let t = new Date().getTime(); |
| 92 | |
| 93 | onTrainBegin(); |
| 94 | const callbacks = { |
| 95 | onBatchEnd: async (batch, logs) => { |
| 96 | // Calculate the training speed in the current batch, in # of |
| 97 | // examples per second. |
| 98 | const t1 = new Date().getTime(); |
| 99 | const examplesPerSec = batchSize / ((t1 - t) / 1e3); |
| 100 | t = t1; |
| 101 | onTrainBatchEnd(logs, ++batchCount / totalBatches, examplesPerSec); |
| 102 | }, |
| 103 | onEpochEnd: async (epoch, logs) => { |
| 104 | onTrainEpochEnd(logs); |
| 105 | } |
| 106 | }; |
| 107 | |
| 108 | await model.fitModel( |
| 109 | this.model, this.textData_, numEpochs, examplesPerEpoch, batchSize, |
| 110 | validationSplit, callbacks); |
| 111 | } |
| 112 | |
| 113 | /** |
| 114 | * Generate text using the LSTM model. |
no test coverage detected