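Decode (i.e., translate) an encoded sentence. Args: input_seq: A `numpy.ndarray` of shape `(1, max_encoder_seq_length, num_encoder_tokens)`. encoder_model: A `keras.Model` instance for the encoder. decoder_model: A `keras.Model` instance for the decoder. num_decoder_tokens: Number of unique tokens for the decoder. target_begin_index: An `int`: the index for the beginning token of the decoder. reverse_target_char_index: A lookup table for the target characters, i.e., a map from `int` index to target character. max_decoder_seq_length: Maximum allowed sequence length output by the decoder. Returns: The result of the decoding (i.e., translation) as a string.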
(input_seq,
encoder_model,
decoder_model,
num_decoder_tokens,
target_begin_index,
reverse_target_char_index,
max_decoder_seq_length)
| 169 | |
| 170 | |
| 171 | def decode_sequence(input_seq, |
| 172 | encoder_model, |
| 173 | decoder_model, |
| 174 | num_decoder_tokens, |
| 175 | target_begin_index, |
| 176 | reverse_target_char_index, |
| 177 | max_decoder_seq_length): |
| 178 | """Decode (i.e., translate) an encoded sentence. |
| 179 | |
| 180 | Args: |
| 181 | input_seq: A `numpy.ndarray` of shape |
| 182 | `(1, max_encoder_seq_length, num_encoder_tokens)`. |
| 183 | encoder_model: A `keras.Model` instance for the encoder. |
| 184 | decoder_model: A `keras.Model` instance for the decoder. |
| 185 | num_decoder_tokens: Number of unique tokens for the decoder. |
| 186 | target_begin_index: An `int`: the index for the beginning token of the |
| 187 | decoder. |
| 188 | reverse_target_char_index: A lookup table for the target characters, i.e., |
| 189 | a map from `int` index to target character. |
| 190 | max_decoder_seq_length: Maximum allowed sequence length output by the |
| 191 | decoder. |
| 192 | |
| 193 | Returns: |
| 194 | The result of the decoding (i.e., translation) as a string. |
| 195 | """ |
| 196 | |
| 197 | # Encode the input as state vectors. |
| 198 | states_value = encoder_model.predict(input_seq) |
| 199 | |
| 200 | # Generate empty target sequence of length 1. |
| 201 | target_seq = np.zeros((1, 1, num_decoder_tokens)) |
| 202 | # Populate the first character of target sequence with the start character. |
| 203 | target_seq[0, 0, target_begin_index] = 1. |
| 204 | |
| 205 | # Sampling loop for a batch of sequences |
| 206 | # (to simplify, here we assume a batch of size 1). |
| 207 | stop_condition = False |
| 208 | decoded_sentence = '' |
| 209 | while not stop_condition: |
| 210 | output_tokens, h, c = decoder_model.predict( |
| 211 | [target_seq] + states_value) |
| 212 | |
| 213 | # Sample a token |
| 214 | sampled_token_index = np.argmax(output_tokens[0, -1, :]) |
| 215 | sampled_char = reverse_target_char_index[sampled_token_index] |
| 216 | decoded_sentence += sampled_char |
| 217 | |
| 218 | # Exit condition: either hit max length |
| 219 | # or find stop character. |
| 220 | if (sampled_char == '\n' or |
| 221 | len(decoded_sentence) > max_decoder_seq_length): |
| 222 | stop_condition = True |
| 223 | |
| 224 | # Update the target sequence (of length 1). |
| 225 | target_seq = np.zeros((1, 1, num_decoder_tokens)) |
| 226 | target_seq[0, 0, sampled_token_index] = 1. |
| 227 | |
| 228 | # Update states |