(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9)
| 62 | |
# Translate the given input
def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9):
    """Encode ``enc_input`` and sample a decoded string one character at a time.

    The module-level ``encoder`` runs over the whole input string. Its final
    hidden state seeds the module-level ``decoder``, which is fed
    ``SOS_token`` first and then each character it has just sampled.
    Generation stops at ``EOS_token`` or after ``predict_len`` characters,
    whichever comes first.

    Args:
        enc_input: Source string to translate.
        predict_len: Maximum number of characters to generate.
        temperature: Sampling temperature; must be > 0. Values below 1 sharpen
            the distribution (greedier), values above 1 flatten it.

    Returns:
        A ``(enc_input, predicted)`` tuple; ``predicted`` does not include the
        EOS marker.

    Raises:
        ValueError: If ``temperature`` is not positive (dividing by it would
            otherwise yield inf/NaN sampling weights).
    """
    if temperature <= 0:
        raise ValueError('temperature must be positive, got {!r}'.format(temperature))

    input_var = str2tensor(enc_input)
    encoder_hidden = encoder.init_hidden()
    # Only the encoder's final hidden state is used to seed the decoder.
    _, hidden = encoder(input_var, encoder_hidden)

    predicted_chars = []
    dec_input = str2tensor(SOS_token)
    for _ in range(predict_len):
        output, hidden = decoder(dec_input, hidden)

        # Sample from the network as a multinomial distribution:
        # exp(output / T) gives unnormalised weights, which torch.multinomial
        # accepts directly. Assumes ``output`` holds log-probabilities (or
        # logits) over character codes -- TODO confirm against sm.DecoderRNN.
        output_dist = output.data.view(-1).div(temperature).exp()
        # int() turns the sampled index (a 0-d tensor on PyTorch >= 0.4, a
        # plain int on older versions) into a Python int so that the EOS
        # comparison is by value and chr() receives an int.
        top_i = int(torch.multinomial(output_dist, 1)[0])

        # Stop at the EOS. Must be ``==``, not ``is``: identity never matches
        # a tensor, and only matches an int via CPython small-int caching.
        # NOTE(review): assumes EOS_token is an integer character code.
        if top_i == EOS_token:
            break

        predicted_char = chr(top_i)
        predicted_chars.append(predicted_char)

        # Feed the sampled character back in as the next decoder input.
        dec_input = str2tensor(predicted_char)

    return enc_input, ''.join(predicted_chars)
| 90 | |
| 91 | |
# Instantiate the encoder network from the module-level size constants.
# NOTE(review): argument meanings (vocabulary size, hidden size, layer count)
# are inferred from the constant names -- confirm against sm.EncoderRNN.
encoder = sm.EncoderRNN(
    N_CHARS,
    HIDDEN_SIZE,
    N_LAYERS,
)
no test coverage detected