hub / github.com/hunkim/PyTorchZeroToAll / translate

Function translate

14_1_seq2seq.py:64–89
(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9)


# Translate the given input
def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9):
    input_var = str2tensor(enc_input)
    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden)

    hidden = encoder_hidden

    predicted = ''
    dec_input = str2tensor(SOS_token)
    for c in range(predict_len):
        output, hidden = decoder(dec_input, hidden)

        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        # int() gives a plain Python int whether indexing returns an int or a 0-dim tensor
        top_i = int(torch.multinomial(output_dist, 1)[0])

        # Stop at the EOS (compare by value; `is` tests object identity)
        if top_i == EOS_token:
            break

        predicted_char = chr(top_i)
        predicted += predicted_char

        dec_input = str2tensor(predicted_char)

    return enc_input, predicted


encoder = sm.EncoderRNN(N_CHARS, HIDDEN_SIZE, N_LAYERS)
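
The sampling step in `translate` divides the decoder scores by `temperature`, exponentiates them, and draws one index in proportion to the resulting weights. The following is a minimal sketch of that idea in plain Python, not the repository's code: `sample_with_temperature` is a hypothetical name, the scores are assumed to be logits or log-probabilities, and `random.choices` stands in for `torch.multinomial` (both accept unnormalized weights).

```python
import math
import random


def sample_with_temperature(scores, temperature=0.9, rng=random):
    """Draw one index with probability proportional to exp(score / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [s / temperature for s in scores]
    # Subtracting the maximum rescales every weight by the same factor,
    # so the sampling distribution is unchanged but exp() cannot overflow.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # random.choices treats the weights as relative; no normalization is needed.
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


# Illustrative scores for a three-symbol vocabulary (made-up values).
example_scores = [math.log(0.7), math.log(0.2), math.log(0.1)]
seeded = random.Random(0)
print(sample_with_temperature(example_scores, temperature=0.9, rng=seeded))
```

A lower temperature sharpens the distribution toward the highest-scoring index, while a higher one flattens it, which is why `temperature=0.9` gives output slightly more conservative than sampling from the raw distribution.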

Callers 1

14_1_seq2seq.py · File · 0.70

Calls 2

str2tensor · Function · 0.90
init_hidden · Method · 0.45
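
`translate` turns each sampled index back into text with `chr(...)`, which suggests that `str2tensor` encodes characters as their code points. The definition of `str2tensor` is not shown here, so the following is only a sketch of that assumed encoding in plain Python, with hypothetical helper names and no tensor wrapping.

```python
def str_to_codes(text):
    """Encode each character as its Unicode code point (hypothetical helper)."""
    return [ord(ch) for ch in text]


def codes_to_str(codes):
    """Inverse of str_to_codes: map code points back to characters."""
    return "".join(chr(code) for code in codes)


codes = str_to_codes("hi.")
print(codes)                # [104, 105, 46]
print(codes_to_str(codes))  # hi.
```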

Tested by

no test coverage detected