(sampleLen, charSetSize, lstmLayerSizes)
| 30 | * `[null, charSetSize]`. |
| 31 | */ |
export function createModel(sampleLen, charSetSize, lstmLayerSizes) {
  // Build a sequential model: one LSTM layer per entry of `lstmLayerSizes`,
  // followed by a softmax dense layer with `charSetSize` units.
  //
  // `lstmLayerSizes` may be a single size or an array of sizes. It is
  // normalized into a local array so the parameter itself is not reassigned.
  const layerSizes =
      Array.isArray(lstmLayerSizes) ? lstmLayerSizes : [lstmLayerSizes];

  // With no LSTM layers, the dense layer below would be the first layer and
  // would be added without any input shape. Fail early with a clear message.
  if (layerSizes.length === 0) {
    throw new Error(
        'createModel: lstmLayerSizes must contain at least one LSTM layer ' +
        'size.');
  }

  const model = tf.sequential();
  const lastLayerIndex = layerSizes.length - 1;
  for (const [i, units] of layerSizes.entries()) {
    model.add(tf.layers.lstm({
      units,
      // true for every LSTM layer except the last one.
      returnSequences: i < lastLayerIndex,
      // Only the first layer is given an explicit input shape.
      inputShape: i === 0 ? [sampleLen, charSetSize] : undefined
    }));
  }
  model.add(tf.layers.dense({units: charSetSize, activation: 'softmax'}));

  return model;
}
| 51 | |
| 52 | export function compileModel(model, learningRate) { |
| 53 | const optimizer = tf.train.rmsprop(learningRate); |
no outgoing calls
no test coverage detected