()
| 232 | |
| 233 | |
| 234 | def main(): |
| 235 | (input_texts, _, max_decoder_seq_length, |
| 236 | num_encoder_tokens, num_decoder_tokens, |
| 237 | __, target_token_index, |
| 238 | encoder_input_data, decoder_input_data, decoder_target_data) = read_data() |
| 239 | |
| 240 | (encoder_inputs, encoder_states, decoder_inputs, decoder_lstm, |
| 241 | decoder_dense, model) = seq2seq_model( |
| 242 | num_encoder_tokens, num_decoder_tokens, FLAGS.latent_dim) |
| 243 | |
| 244 | # Run training. |
| 245 | model.compile(optimizer='rmsprop', loss='categorical_crossentropy') |
| 246 | model.fit([encoder_input_data, decoder_input_data], decoder_target_data, |
| 247 | batch_size=FLAGS.batch_size, |
| 248 | epochs=FLAGS.epochs, |
| 249 | validation_split=0.2) |
| 250 | |
| 251 | tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir) |
| 252 | |
| 253 | # Next: inference mode (sampling). |
| 254 | # Here's the drill: |
| 255 | # 1) encode input and retrieve initial decoder state |
| 256 | # 2) run one step of decoder with this initial state |
| 257 | # and a "start of sequence" token as target. |
| 258 | # Output will be the next target token |
| 259 | # 3) Repeat with the current target token and current states |
| 260 | |
| 261 | # Define sampling models |
| 262 | encoder_model = Model(encoder_inputs, encoder_states) |
| 263 | |
| 264 | decoder_state_input_h = Input(shape=(FLAGS.latent_dim,)) |
| 265 | decoder_state_input_c = Input(shape=(FLAGS.latent_dim,)) |
| 266 | decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] |
| 267 | decoder_outputs, state_h, state_c = decoder_lstm( |
| 268 | decoder_inputs, initial_state=decoder_states_inputs) |
| 269 | decoder_states = [state_h, state_c] |
| 270 | decoder_outputs = decoder_dense(decoder_outputs) |
| 271 | decoder_model = Model( |
| 272 | [decoder_inputs] + decoder_states_inputs, |
| 273 | [decoder_outputs] + decoder_states) |
| 274 | |
| 275 | # Reverse-lookup token index to decode sequences back to |
| 276 | # something readable. |
| 277 | reverse_target_char_index = dict( |
| 278 | (i, char) for char, i in target_token_index.items()) |
| 279 | |
| 280 | target_begin_index = target_token_index['\t'] |
| 281 | |
| 282 | for seq_index in range(FLAGS.num_test_sentences): |
| 283 | # Take one sequence (part of the training set) |
| 284 | # for trying out decoding. |
| 285 | input_seq = encoder_input_data[seq_index: seq_index + 1] |
| 286 | # Get expected output |
| 287 | target_seq = decoder_target_data[seq_index] |
| 288 | # One-hot to index |
| 289 | target_seq = [ |
| 290 | reverse_target_char_index[np.argmax(c)] for c in target_seq |
| 291 | ] |
no test coverage detected