(decoder, prime_str='A', predict_len=100, temperature=0.8)
| 56 | |
| 57 | |
def generate(decoder, prime_str='A', predict_len=100, temperature=0.8):
    """Sample text from a character-level recurrent decoder.

    The decoder is first fed ``prime_str`` to build up its hidden state, then
    run autoregressively for ``predict_len`` steps. Each step samples the next
    character from the decoder's temperature-scaled output and feeds it back
    in as the next input.

    Args:
        decoder: Object exposing ``init_hidden()`` and callable as
            ``decoder(inp, hidden) -> (output, hidden)``.
        prime_str: Non-empty seed text; it is returned as the prefix of the
            result.
        predict_len: Number of characters to sample after the prime.
        temperature: Positive scaling factor applied to the decoder output
            before sampling. Lower values make sampling greedier; higher
            values make it more random.

    Returns:
        ``prime_str`` followed by ``predict_len`` sampled characters.

    Raises:
        ValueError: If ``prime_str`` is empty or ``temperature`` is not
            positive.
    """
    if not prime_str:
        # prime_input[-1] below needs at least one character.
        raise ValueError("prime_str must be a non-empty string")
    if temperature <= 0:
        # Dividing by a zero temperature would yield inf/nan sampling weights.
        raise ValueError("temperature must be positive")

    # Inference only: avoid growing an autograd graph through the hidden
    # state across every sampling step.
    with torch.no_grad():
        hidden = decoder.init_hidden()
        prime_input = str2tensor(prime_str)

        # Use all but the last prime character to "build up" the hidden state;
        # the last character is the first input of the sampling loop.
        for p in range(len(prime_str) - 1):
            _, hidden = decoder(prime_input[p], hidden)

        inp = prime_input[-1]
        chars = [prime_str]

        for _ in range(predict_len):
            output, hidden = decoder(inp, hidden)

            # Temperature-scale and exponentiate. Subtracting the max first
            # prevents overflow to inf, which would make multinomial fail.
            # The distribution is unchanged because multinomial normalizes
            # its (non-negative) weights.
            scores = output.view(-1) / temperature
            output_dist = (scores - scores.max()).exp()
            top_i = torch.multinomial(output_dist, 1)[0].item()

            # The sampled index is used directly as a code point. This
            # presumably mirrors str2tensor's encoding; confirm against it.
            predicted_char = chr(top_i)
            chars.append(predicted_char)
            inp = str2tensor(predicted_char)

    return ''.join(chars)
| 82 | |
| 83 | # Train for a given src and target |
| 84 | # It feeds a single string to demonstrate seq2seq |
no test coverage detected