(modelType, numTimeSteps, numFeatures)
| 166 | * @returns A compiled instance of `tf.LayersModel`. |
| 167 | */ |
export function buildModel(modelType, numTimeSteps, numFeatures) {
  // Every model variant consumes the same [timeSteps, features] input shape.
  const shape = [numTimeSteps, numFeatures];

  console.log(`modelType = ${modelType}`);

  // Lookup table from model-type name to a thunk that builds the uncompiled
  // model. Thunks keep construction lazy, so e.g. the L2 regularizer is only
  // created when 'mlp-l2' is requested. A Map (not a plain object) is used so
  // inherited keys such as 'toString' are never mistaken for a model type.
  const builders = new Map([
    ['mlp', () => buildMLPModel(shape)],
    ['mlp-l2', () => buildMLPModel(shape, tf.regularizers.l2())],
    ['linear-regression', () => buildLinearRegressionModel(shape)],
    ['mlp-dropout', () => {
      // No kernel regularizer; only dropout is applied in this variant.
      const noRegularizer = null;
      const rate = 0.25;
      return buildMLPModel(shape, noRegularizer, rate);
    }],
    ['simpleRNN', () => buildSimpleRNNModel(shape)],
    // TODO(cais): Add gru-dropout with recurrentDropout.
    ['gru', () => buildGRUModel(shape)],
  ]);

  const build = builders.get(modelType);
  if (build === undefined) {
    throw new Error(`Unsupported model type: ${modelType}`);
  }

  const compiledModel = build();
  compiledModel.compile({loss: 'meanAbsoluteError', optimizer: 'rmsprop'});
  compiledModel.summary();
  return compiledModel;
}
| 196 | |
| 197 | /** |
| 198 | * Train a model on the Jena weather data. |
no test coverage detected