Repository: github.com/hunkim/PyTorchZeroToAll

Function generate

Source: 13_3_char_rnn.py, lines 58–81
(decoder, prime_str='A', predict_len=100, temperature=0.8)


```python
import torch  # str2tensor and the decoder class are defined elsewhere in 13_3_char_rnn.py


@torch.no_grad()  # sampling only, so no autograd graph is needed
def generate(decoder, prime_str='A', predict_len=100, temperature=0.8):
    hidden = decoder.init_hidden()
    prime_input = str2tensor(prime_str)
    predicted = prime_str

    # Use the priming string to "build up" the hidden state
    for p in range(len(prime_str) - 1):
        _, hidden = decoder(prime_input[p], hidden)

    inp = prime_input[-1]

    for p in range(predict_len):
        output, hidden = decoder(inp, hidden)

        # Sample the next character index from a multinomial distribution.
        # softmax(output / T) is proportional to exp(output / T) but cannot overflow.
        output_dist = torch.softmax(output.view(-1) / temperature, dim=0)
        top_i = torch.multinomial(output_dist, 1).item()  # plain Python int

        # Add the predicted character to the string and use it as the next input
        predicted_char = chr(top_i)
        predicted += predicted_char
        inp = str2tensor(predicted_char)

    return predicted
```
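The sampling step in generate divides the decoder's scores by temperature and exponentiates them, so each character is drawn with probability proportional to exp(score / temperature). Temperatures below 1 sharpen the distribution toward the highest-scoring character; temperatures above 1 flatten it. A minimal pure-Python sketch of that rule, using made-up logits and the hypothetical helpers temperature_probs and sample_index, with random.choices standing in for torch.multinomial:

```python
import math
import random


def temperature_probs(logits, temperature=0.8):
    """Probabilities proportional to exp(logit / temperature).

    Subtracting the largest scaled logit before exp() avoids overflow and
    does not change the normalised result.
    """
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    weights = [math.exp(s - top) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_index(logits, temperature=0.8, rng=random):
    """Draw one index, playing the role of torch.multinomial(dist, 1)."""
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-symbol vocabulary
for t in (0.5, 0.8, 1.0, 2.0):
    print(t, [round(p, 3) for p in temperature_probs(logits, t)])
print(sample_index(logits, 0.8, random.Random(0)))
```

With these scores the top symbol gets roughly 0.86 of the probability at T = 0.5 but only about 0.50 at T = 2.0, which is why higher temperatures yield more varied, and more error-prone, text.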

Callers (1)

13_3_char_rnn.py · file · 0.85

Calls (2)

str2tensor · function · 0.70
init_hidden · method · 0.45
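generate relies on a round trip between characters and integer indices: str2tensor encodes the prompt, and chr(top_i) decodes each sampled index. The source of str2tensor is not shown on this page; the sketch assumes it maps each character to its code point with ord(), which is what the chr() call implies. str2codes and codes2str are hypothetical pure-Python stand-ins, not functions from the repository:

```python
def str2codes(string):
    """Hypothetical stand-in for str2tensor: one integer code per character."""
    return [ord(c) for c in string]


def codes2str(codes):
    """Inverse mapping: the same chr() step generate() applies to each sampled index."""
    return "".join(chr(i) for i in codes)


codes = str2codes("Hi!")
print(codes)             # → [72, 105, 33]
print(codes2str(codes))  # → Hi!
```

Under this assumption, the decoder's embedding and output layers must be sized to cover every code point that can appear (for example, 128 entries for plain ASCII).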

Tested by

No test coverage detected.
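No tests are listed for generate. A test mainly needs to pin down its control flow: the hidden state is warmed up on every priming character except the last, the last priming character becomes the first input, and exactly predict_len sampled characters are appended to prime_str. The sketch below ports that loop to pure Python so it can be checked without PyTorch. generate_sketch, step, and toy_step are hypothetical names, and the toy decoder simply favours the character after its input:

```python
import math
import random


def generate_sketch(step, init_hidden, prime_str="A", predict_len=100,
                    temperature=0.8, rng=random):
    """Pure-Python port of generate()'s control flow.

    step(code, hidden) -> (logits, hidden) stands in for decoder(inp, hidden);
    init_hidden() stands in for decoder.init_hidden(). Logits are indexed by
    character code, so chr(index) recovers the sampled character.
    """
    hidden = init_hidden()
    codes = [ord(c) for c in prime_str]
    predicted = prime_str

    # Warm up the hidden state on every priming character except the last.
    for code in codes[:-1]:
        _, hidden = step(code, hidden)

    inp = codes[-1]
    for _ in range(predict_len):
        logits, hidden = step(inp, hidden)
        # Weights proportional to exp(logit / temperature), shifted for stability.
        scaled = [x / temperature for x in logits]
        top = max(scaled)
        weights = [math.exp(s - top) for s in scaled]
        idx = rng.choices(range(len(weights)), weights=weights, k=1)[0]
        predicted += chr(idx)
        inp = idx  # the sampled character is the next input
    return predicted


def toy_step(code, hidden):
    """Toy decoder over 128 codes: strongly favours the code after its input."""
    logits = [0.0] * 128
    logits[(code + 1) % 128] = 50.0
    return logits, hidden + [code]  # "hidden state" = codes seen so far


print(generate_sketch(toy_step, lambda: [], prime_str="AB", predict_len=3,
                      rng=random.Random(0)))  # → ABCDE
```

The same checks (prefix preserved, total length len(prime_str) + predict_len) carry over to the real generate when it is driven by a stub or trained decoder.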