| 74 | } |
| 75 | |
| 76 | prepareDecoderModel(model) { |
| 77 | this.numDecoderTokens = model.input[1].shape[2]; |
| 78 | console.log('numDecoderTokens = ' + this.numDecoderTokens); |
| 79 | |
| 80 | const stateH = model.layers[2].output[1]; |
| 81 | const latentDim = stateH.shape[stateH.shape.length - 1]; |
| 82 | console.log('latentDim = ' + latentDim); |
| 83 | const decoderStateInputH = |
| 84 | tf.input({shape: [latentDim], name: 'decoder_state_input_h'}); |
| 85 | const decoderStateInputC = |
| 86 | tf.input({shape: [latentDim], name: 'decoder_state_input_c'}); |
| 87 | const decoderStateInputs = [decoderStateInputH, decoderStateInputC]; |
| 88 | |
| 89 | const decoderLSTM = model.layers[3]; |
| 90 | const decoderInputs = decoderLSTM.input[0]; |
| 91 | const applyOutputs = |
| 92 | decoderLSTM.apply(decoderInputs, {initialState: decoderStateInputs}); |
| 93 | let decoderOutputs = applyOutputs[0]; |
| 94 | const decoderStateH = applyOutputs[1]; |
| 95 | const decoderStateC = applyOutputs[2]; |
| 96 | const decoderStates = [decoderStateH, decoderStateC]; |
| 97 | |
| 98 | const decoderDense = model.layers[4]; |
| 99 | decoderOutputs = decoderDense.apply(decoderOutputs); |
| 100 | this.decoderModel = tf.model({ |
| 101 | inputs: [decoderInputs].concat(decoderStateInputs), |
| 102 | outputs: [decoderOutputs].concat(decoderStates) |
| 103 | }); |
| 104 | } |
| 105 | |
| 106 | /** |
| 107 | * Encode a string (e.g., a sentence) as a Tensor3D that can be fed directly |