MCPcopy
hub / github.com/tensorflow/tfjs-examples / main

Function main

translation/python/translation.py:234–300  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

232
233
234def main():
235 (input_texts, _, max_decoder_seq_length,
236 num_encoder_tokens, num_decoder_tokens,
237 __, target_token_index,
238 encoder_input_data, decoder_input_data, decoder_target_data) = read_data()
239
240 (encoder_inputs, encoder_states, decoder_inputs, decoder_lstm,
241 decoder_dense, model) = seq2seq_model(
242 num_encoder_tokens, num_decoder_tokens, FLAGS.latent_dim)
243
244 # Run training.
245 model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
246 model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
247 batch_size=FLAGS.batch_size,
248 epochs=FLAGS.epochs,
249 validation_split=0.2)
250
251 tfjs.converters.save_keras_model(model, FLAGS.artifacts_dir)
252
253 # Next: inference mode (sampling).
254 # Here's the drill:
255 # 1) encode input and retrieve initial decoder state
256 # 2) run one step of decoder with this initial state
257 # and a "start of sequence" token as target.
258 # Output will be the next target token
259 # 3) Repeat with the current target token and current states
260
261 # Define sampling models
262 encoder_model = Model(encoder_inputs, encoder_states)
263
264 decoder_state_input_h = Input(shape=(FLAGS.latent_dim,))
265 decoder_state_input_c = Input(shape=(FLAGS.latent_dim,))
266 decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
267 decoder_outputs, state_h, state_c = decoder_lstm(
268 decoder_inputs, initial_state=decoder_states_inputs)
269 decoder_states = [state_h, state_c]
270 decoder_outputs = decoder_dense(decoder_outputs)
271 decoder_model = Model(
272 [decoder_inputs] + decoder_states_inputs,
273 [decoder_outputs] + decoder_states)
274
275 # Reverse-lookup token index to decode sequences back to
276 # something readable.
277 reverse_target_char_index = dict(
278 (i, char) for char, i in target_token_index.items())
279
280 target_begin_index = target_token_index['\t']
281
282 for seq_index in range(FLAGS.num_test_sentences):
283 # Take one sequence (part of the training set)
284 # for trying out decoding.
285 input_seq = encoder_input_data[seq_index: seq_index + 1]
286 # Get expected output
287 target_seq = decoder_target_data[seq_index]
288 # One-hot to index
289 target_seq = [
290 reverse_target_char_index[np.argmax(c)] for c in target_seq
291 ]

Callers 1

translation.py (File) · 0.70

Calls 3

read_data (Function) · 0.85
seq2seq_model (Function) · 0.85
decode_sequence (Function) · 0.85

Tested by

no test coverage detected