(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9)
| 78 | |
def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9):
    """Translate ``enc_input`` by sampling from the attention seq2seq model.

    Encodes ``enc_input`` with ``encoder``, then runs ``decoder`` one step at a
    time, feeding each sampled character back in as the next decoder input.
    Decoding stops at ``EOS_token`` or after ``predict_len`` steps.

    Args:
        enc_input: Source string to translate.
        predict_len: Maximum number of decoding steps.
        temperature: Sampling temperature. The decoder output is divided by
            this before ``exp()``, so values < 1 sharpen the sampling
            distribution and values > 1 flatten it.

    Returns:
        A ``(predicted, attentions)`` tuple. ``predicted`` is the generated
        string, excluding the EOS symbol. ``attentions`` holds one flattened
        Python list of attention weights per decoding step (including the
        step that produced EOS, if any).
    """
    input_var = str2tensor(enc_input)
    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden)

    hidden = encoder_hidden

    predicted_chars = []
    dec_input = str2tensor(SOS_token)
    attentions = []
    for _ in range(predict_len):
        output, hidden, attention = decoder(dec_input, hidden, encoder_outputs)
        # Sample from the network as a multinomial distribution.
        # NOTE(review): exp() presumes the decoder emits log-probabilities — confirm.
        output_dist = output.data.view(-1).div(temperature).exp()
        # Convert to a plain int. Indexing torch.multinomial's result yields a
        # 0-dim tensor (PyTorch >= 0.4), which is never identical to an int.
        top_i = int(torch.multinomial(output_dist, 1)[0])
        attentions.append(attention.view(-1).data.cpu().numpy().tolist())

        # Stop at the EOS. Compare by value (`==`), not identity (`is`).
        # Assumes EOS_token is an integer index, consistent with chr(top_i)
        # below — TODO confirm against its definition.
        if top_i == EOS_token:
            break

        predicted_char = chr(top_i)
        predicted_chars.append(predicted_char)

        dec_input = str2tensor(predicted_char)

    return ''.join(predicted_chars), attentions
| 107 | |
| 108 | |
| 109 | if __name__ == '__main__': |
no test coverage detected