| 129 | } |
| 130 | |
| 131 | decodeSequence(inputSeq) { |
| 132 | // Encode the inputs state vectors. |
| 133 | let statesValue = this.encoderModel.predict(inputSeq); |
| 134 | |
| 135 | // Generate empty target sequence of length 1. |
| 136 | let targetSeq = tf.buffer([1, 1, this.numDecoderTokens]); |
| 137 | // Populate the first character of the target sequence with the start |
| 138 | // character. |
| 139 | targetSeq.set(1, 0, 0, this.targetTokenIndex['\t']); |
| 140 | |
| 141 | // Sample loop for a batch of sequences. |
| 142 | // (to simplify, here we assume that a batch of size 1). |
| 143 | let stopCondition = false; |
| 144 | let decodedSentence = ''; |
| 145 | while (!stopCondition) { |
| 146 | const predictOutputs = |
| 147 | this.decoderModel.predict([targetSeq.toTensor()].concat(statesValue)); |
| 148 | const outputTokens = predictOutputs[0]; |
| 149 | const h = predictOutputs[1]; |
| 150 | const c = predictOutputs[2]; |
| 151 | |
| 152 | // Sample a token. |
| 153 | // We know that outputTokens.shape is [1, 1, n], so no need for slicing. |
| 154 | const logits = outputTokens.reshape([outputTokens.shape[2]]); |
| 155 | const sampledTokenIndex = logits.argMax().dataSync()[0]; |
| 156 | const sampledChar = this.reverseTargetCharIndex[sampledTokenIndex]; |
| 157 | decodedSentence += sampledChar; |
| 158 | |
| 159 | // Exit condition: either hit max length or find stop character. |
| 160 | if (sampledChar === '\n' || |
| 161 | decodedSentence.length > this.maxDecoderSeqLength) { |
| 162 | stopCondition = true; |
| 163 | } |
| 164 | |
| 165 | // Update the target sequence (of length 1). |
| 166 | targetSeq = tf.buffer([1, 1, this.numDecoderTokens]); |
| 167 | targetSeq.set(1, 0, 0, sampledTokenIndex); |
| 168 | |
| 169 | // Update states. |
| 170 | statesValue = [h, c]; |
| 171 | } |
| 172 | |
| 173 | return decodedSentence; |
| 174 | } |
| 175 | |
| 176 | /** Translate the given English sentence into French. */ |
| 177 | translate(inputSentence) { |