MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / decodeSequence

Method decodeSequence

translation/index.js:131–174  ·  view source on GitHub ↗
(inputSeq)

Source from the content-addressed store, hash-verified

129 }
130
131 decodeSequence(inputSeq) {
132 // Encode the inputs state vectors.
133 let statesValue = this.encoderModel.predict(inputSeq);
134
135 // Generate empty target sequence of length 1.
136 let targetSeq = tf.buffer([1, 1, this.numDecoderTokens]);
137 // Populate the first character of the target sequence with the start
138 // character.
139 targetSeq.set(1, 0, 0, this.targetTokenIndex['\t']);
140
141 // Sample loop for a batch of sequences.
142 // (to simplify, here we assume that a batch of size 1).
143 let stopCondition = false;
144 let decodedSentence = '';
145 while (!stopCondition) {
146 const predictOutputs =
147 this.decoderModel.predict([targetSeq.toTensor()].concat(statesValue));
148 const outputTokens = predictOutputs[0];
149 const h = predictOutputs[1];
150 const c = predictOutputs[2];
151
152 // Sample a token.
153 // We know that outputTokens.shape is [1, 1, n], so no need for slicing.
154 const logits = outputTokens.reshape([outputTokens.shape[2]]);
155 const sampledTokenIndex = logits.argMax().dataSync()[0];
156 const sampledChar = this.reverseTargetCharIndex[sampledTokenIndex];
157 decodedSentence += sampledChar;
158
159 // Exit condition: either hit max length or find stop character.
160 if (sampledChar === '\n' ||
161 decodedSentence.length > this.maxDecoderSeqLength) {
162 stopCondition = true;
163 }
164
165 // Update the target sequence (of length 1).
166 targetSeq = tf.buffer([1, 1, this.numDecoderTokens]);
167 targetSeq.set(1, 0, 0, sampledTokenIndex);
168
169 // Update states.
170 statesValue = [h, c];
171 }
172
173 return decodedSentence;
174 }
175
176 /** Translate the given English sentence into French. */
177 translate(inputSentence) {

Callers 1

translate · Method · 0.95

Calls 1

predict · Method · 0.45

Tested by

no test coverage detected