MCPcopy
hub / github.com/hunkim/PyTorchZeroToAll / translate

Function translate

14_2_seq2seq_att.py:80–106  ·  view source on GitHub ↗
(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9)

Source from the content-addressed store, hash-verified

78
79# Translate the given input
def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9):
    """Encode *enc_input* and sample an output string from the decoder.

    The input string is converted with ``str2tensor`` and run through the
    module-level ``encoder``; the resulting hidden state seeds the
    module-level ``decoder``, which is then stepped one character at a time
    starting from ``SOS_token``.

    Args:
        enc_input: Source string to encode.
        predict_len: Maximum number of decoder steps to run.
        temperature: Divisor applied to the decoder output before
            exponentiation; smaller values sharpen the sampling distribution.

    Returns:
        A ``(predicted, attentions)`` tuple. ``predicted`` is the sampled
        string (without the EOS symbol). ``attentions`` holds one flattened
        list of attention values per executed decoder step, including the
        step that produced EOS.
    """
    input_var = str2tensor(enc_input)
    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden)

    # The decoder starts from the encoder's final hidden state.
    hidden = encoder_hidden

    predicted_chars = []
    dec_input = str2tensor(SOS_token)
    attentions = []
    for _ in range(predict_len):
        output, hidden, attention = decoder(dec_input, hidden, encoder_outputs)

        # Sample the next symbol from a multinomial distribution.
        # NOTE(review): exp() presumably assumes the decoder emits
        # log-probabilities or logits -- confirm against the decoder.
        output_dist = output.data.view(-1).div(temperature).exp()
        # int() yields a plain Python integer whether indexing returns an int
        # or a 0-dim tensor, so the comparison and chr() below are well-defined.
        top_i = int(torch.multinomial(output_dist, 1)[0])
        attentions.append(attention.view(-1).data.cpu().numpy().tolist())

        # Stop at the EOS. Compare by value: an identity test (`is`) against
        # a freshly sampled index is not a reliable equality check.
        if top_i == EOS_token:
            break

        predicted_char = chr(top_i)
        predicted_chars.append(predicted_char)

        # Feed the sampled character back in as the next decoder input.
        dec_input = str2tensor(predicted_char)

    return ''.join(predicted_chars), attentions
107
108
109if __name__ == '__main__':

Callers 1

Calls 2

str2tensorFunction · 0.90
init_hiddenMethod · 0.45

Tested by

no test coverage detected