| 184 | |
| 185 | |
def sample_line():
    """Sample one line of text from the decoder, one token at a time.

    Starts from the '<sos>' token with zeroed hidden/cell states and
    repeatedly feeds the previously sampled token, plus the h/c states
    returned by ``sampling_model.predict``, back into the model.

    Generation stops when any of these happens:
    - '<eos>' is sampled;
    - no index other than 0 has any probability mass;
    - ``max_sequence_length`` steps have run.

    Relies on module-level names defined elsewhere: ``word2idx``,
    ``idx2word``, ``sampling_model``, ``LATENT_DIM`` and
    ``max_sequence_length``.

    Returns:
        str: the sampled words joined by single spaces (possibly empty).
    """
    # initial inputs: a (1, 1) token batch and zeroed states
    np_input = np.array([[word2idx['<sos>']]])
    h = np.zeros((1, LATENT_DIM))
    c = np.zeros((1, LATENT_DIM))

    # so we know when to quit
    eos = word2idx['<eos>']

    # store the output here
    output_sentence = []

    for _ in range(max_sequence_length):
        o, h, c = sampling_model.predict([np_input, h, c])

        # Work on a float64 copy: renormalisation is then precise, and the
        # model's output array is left untouched.
        probs = o[0, 0].astype(np.float64)
        # Index 0 is never allowed to be sampled (presumably the padding
        # index -- confirm against the tokenizer), so drop its mass.
        probs[0] = 0
        total = probs.sum()
        if not np.isfinite(total) or total <= 0:
            # All mass was on index 0: renormalising would produce NaNs and
            # np.random.choice would raise, so end the line here instead.
            break
        probs /= total
        idx = np.random.choice(len(probs), p=probs)
        if idx == eos:
            break

        # accumulate output
        output_sentence.append(idx2word.get(idx, '<WTF %s>' % idx))

        # make the next input into model
        np_input[0, 0] = idx

    return ' '.join(output_sentence)
| 219 | |
| 220 | # generate a 4 line poem |
| 221 | while True: |