hub / github.com/tensorflow/tfjs-examples / decode_sequence

Function decode_sequence

translation/python/translation.py:171–231

Decode (i.e., translate) an encoded sentence.

Args:
  input_seq: A `numpy.ndarray` of shape `(1, max_encoder_seq_length, num_encoder_tokens)`.
  encoder_model: A `keras.Model` instance for the encoder.
  decoder_model: A `keras.Model` instance for the decoder.
  num_decoder_tokens: …

(input_seq,
 encoder_model,
 decoder_model,
 num_decoder_tokens,
 target_begin_index,
 reverse_target_char_index,
 max_decoder_seq_length)
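
To make the argument shapes concrete, here is a minimal sketch of how `input_seq` and `reverse_target_char_index` might be prepared. The toy vocabularies, the `one_hot_encode` helper, and the choice to leave unused timesteps as zeros are assumptions for illustration; the actual script derives its vocabularies from its training data.

```python
import numpy as np

# Hypothetical toy vocabulary; the real script builds this from its dataset.
input_characters = sorted(set("abc "))
input_token_index = {ch: i for i, ch in enumerate(input_characters)}
num_encoder_tokens = len(input_characters)
max_encoder_seq_length = 6


def one_hot_encode(text):
  """Return a (1, max_encoder_seq_length, num_encoder_tokens) float array."""
  seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens),
                 dtype="float32")
  # Timesteps past the end of `text` stay all-zero (an assumption here).
  for t, ch in enumerate(text[:max_encoder_seq_length]):
    seq[0, t, input_token_index[ch]] = 1.0
  return seq


# Target-side lookup from int index to character, the form that
# reverse_target_char_index is documented to take.
target_characters = ["\t", "\n", "x", "y"]
reverse_target_char_index = dict(enumerate(target_characters))

input_seq = one_hot_encode("ab c")
print(input_seq.shape)
```

The leading dimension of 1 is the batch axis: the function decodes one sentence per call.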

Source from the content-addressed store, hash-verified

def decode_sequence(input_seq,
                    encoder_model,
                    decoder_model,
                    num_decoder_tokens,
                    target_begin_index,
                    reverse_target_char_index,
                    max_decoder_seq_length):
  """Decode (i.e., translate) an encoded sentence.

  Args:
    input_seq: A `numpy.ndarray` of shape
      `(1, max_encoder_seq_length, num_encoder_tokens)`.
    encoder_model: A `keras.Model` instance for the encoder.
    decoder_model: A `keras.Model` instance for the decoder.
    num_decoder_tokens: Number of unique tokens for the decoder.
    target_begin_index: An `int`: the index for the beginning token of the
      decoder.
    reverse_target_char_index: A lookup table for the target characters, i.e.,
      a map from `int` index to target character.
    max_decoder_seq_length: Maximum allowed sequence length output by the
      decoder.

  Returns:
    The result of the decoding (i.e., translation) as a string.
  """

  # Encode the input as state vectors.
  states_value = encoder_model.predict(input_seq)

  # Generate empty target sequence of length 1.
  target_seq = np.zeros((1, 1, num_decoder_tokens))
  # Populate the first character of target sequence with the start character.
  target_seq[0, 0, target_begin_index] = 1.

  # Sampling loop for a batch of sequences
  # (to simplify, here we assume a batch of size 1).
  stop_condition = False
  decoded_sentence = ''
  while not stop_condition:
    output_tokens, h, c = decoder_model.predict(
        [target_seq] + states_value)

    # Sample a token
    sampled_token_index = np.argmax(output_tokens[0, -1, :])
    sampled_char = reverse_target_char_index[sampled_token_index]
    decoded_sentence += sampled_char

    # Exit condition: either hit max length
    # or find stop character.
    if (sampled_char == '\n' or
        len(decoded_sentence) > max_decoder_seq_length):
      stop_condition = True

    # Update the target sequence (of length 1).
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, sampled_token_index] = 1.

    # Update states
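
The listing above stops at the `# Update states` comment. As a rough, self-contained sketch of the same greedy decoding loop, the example below swaps the Keras models for small stub classes. `StubEncoder`, `StubDecoder`, and `greedy_decode` are hypothetical names, and the state update and return statement at the end of the loop are assumed from the docstring and the variables `h` and `c`, not copied from the source.

```python
import numpy as np


class StubEncoder:
  """Hypothetical stand-in for the encoder keras.Model."""

  def predict(self, input_seq):
    # Return two state arrays [h, c], as an LSTM encoder would.
    return [np.zeros((1, 2)), np.zeros((1, 2))]


class StubDecoder:
  """Hypothetical stand-in for the decoder; emits a fixed token script."""

  def __init__(self, script, num_tokens):
    self.script = script
    self.num_tokens = num_tokens

  def predict(self, inputs):
    _, h, c = inputs
    step = int(h[0, 0])  # The stub carries its step counter in the state.
    probs = np.zeros((1, 1, self.num_tokens))
    probs[0, 0, self.script[step]] = 1.0
    return probs, h + 1, c


def greedy_decode(input_seq, encoder, decoder, num_tokens, begin_index,
                  reverse_index, max_len):
  """Greedy (argmax) decoding, one character per decoder step."""
  states = encoder.predict(input_seq)
  target_seq = np.zeros((1, 1, num_tokens))
  target_seq[0, 0, begin_index] = 1.0
  decoded = ''
  while True:
    output_tokens, h, c = decoder.predict([target_seq] + states)
    token = int(np.argmax(output_tokens[0, -1, :]))
    char = reverse_index[token]
    decoded += char
    if char == '\n' or len(decoded) > max_len:
      break
    # Feed the sampled token back in as the next one-hot input.
    target_seq = np.zeros((1, 1, num_tokens))
    target_seq[0, 0, token] = 1.0
    # Assumed continuation: carry the new states into the next step.
    states = [h, c]
  return decoded


reverse_index = {0: '\t', 1: '\n', 2: 'h', 3: 'i'}
decoder = StubDecoder(script=[2, 3, 1], num_tokens=4)
result = greedy_decode(np.zeros((1, 3, 5)), StubEncoder(), decoder,
                       num_tokens=4, begin_index=0,
                       reverse_index=reverse_index, max_len=10)
print(repr(result))  # 'hi\n'
```

Two details of the original loop carry over to this sketch: the stop character `'\n'` is appended before the exit check, so it is part of the returned string, and the length check uses `>`, so the output can reach `max_len + 1` characters before the loop stops.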

Callers 1

main · Function · 0.85

Calls 1

predict · Method · 0.45

Tested by

no test coverage detected