MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / buildModel

Function buildModel

jena-weather/models.js:168–195  ·  view source on GitHub ↗
(modelType, numTimeSteps, numFeatures)

Source from the content-addressed store, hash-verified

166 * @returns A compiled instance of `tf.LayersModel`.
167 */
export function buildModel(modelType, numTimeSteps, numFeatures) {
  // Every architecture below receives the same [numTimeSteps, numFeatures]
  // input shape.
  const shape = [numTimeSteps, numFeatures];
  console.log(`modelType = ${modelType}`);

  // Instantiate the (not yet compiled) architecture named by `modelType`.
  // Unknown names are rejected with an Error.
  const pickArchitecture = () => {
    switch (modelType) {
      case 'mlp':
        return buildMLPModel(shape);
      case 'mlp-l2':
        return buildMLPModel(shape, tf.regularizers.l2());
      case 'linear-regression':
        return buildLinearRegressionModel(shape);
      case 'mlp-dropout': {
        // No weight regularizer; rely on dropout instead.
        const noRegularizer = null;
        const rate = 0.25;
        return buildMLPModel(shape, noRegularizer, rate);
      }
      case 'simpleRNN':
        return buildSimpleRNNModel(shape);
      case 'gru':
        // TODO(cais): Add gru-dropout with recurrentDropout.
        return buildGRUModel(shape);
      default:
        throw new Error(`Unsupported model type: ${modelType}`);
    }
  };

  const model = pickArchitecture();
  // All variants share the same loss and optimizer.
  model.compile({loss: 'meanAbsoluteError', optimizer: 'rmsprop'});
  model.summary();
  return model;
}
196
197/**
198 * Train a model on the Jena weather data.

Callers 2

index.js (File) · 0.90
main (Function) · 0.90

Calls 4

buildMLPModel (Function) · 0.85
buildSimpleRNNModel (Function) · 0.85
buildGRUModel (Function) · 0.85

Tested by

no test coverage detected