:param path: path to the model. :param start: a `str`, the starting characters. :param length: an `int`, the length of text to generate.
(path, start, length)
| 126 | |
| 127 | |
def sample(path, start, length):
    """
    Generate text from a trained char-RNN model and print it.

    :param path: path to the model (passed to ``SmartInit``)
    :param start: a `str`. the starting characters; must be non-empty and
        every character must be a key of the corpus ``char2idx`` map
    :param length: an `int`. the number of characters to generate after `start`
    :returns: `start` followed by the generated characters (also printed)
    :raises ValueError: if `start` is empty
    """
    if not start:
        raise ValueError("`start` must contain at least one character")

    # initialize vocabulary and sequence length
    # (set before the model is built so the predictor consumes one char per call)
    param.seq_len = 1
    ds = CharRNNData(param.corpus, 100000)

    pred = OfflinePredictor(PredictConfig(
        model=Model(),
        session_init=SmartInit(path),
        input_names=['input', 'c0', 'h0', 'c1', 'h1'],
        output_names=['prob', 'last_state']))

    def step(ch, lstm_state):
        """Feed one character with the given state; return (prob, next state)."""
        x = np.array([[ds.char2idx[ch]]], dtype='int32')
        prob, state = pred(x, *lstm_state)
        # `last_state` is indexed [i, 0] -> c_i and [i, 1] -> h_i; unpack it
        # into the (c0, h0, c1, h1) order listed in `input_names`.
        return prob, (state[0, 0], state[0, 1], state[1, 0], state[1, 1])

    def pick(prob):
        """Sample an index with probability proportional to `prob`."""
        cdf = np.cumsum(prob)
        # Scale by the last cdf entry (not a separately computed sum) so the
        # draw never exceeds cdf[-1] and searchsorted stays in range.
        return int(np.searchsorted(cdf, np.random.rand() * cdf[-1]))

    # feed the starting sentence, carrying the state from char to char
    zeros = np.zeros((1, param.rnn_size))
    lstm_state = (zeros, zeros, zeros, zeros)
    for ch in start[:-1]:
        _, lstm_state = step(ch, lstm_state)

    # generate more
    pieces = [start]
    ch = start[-1]
    for _ in range(length):
        prob, lstm_state = step(ch, lstm_state)
        ch = ds.chars[pick(prob[0])]
        pieces.append(ch)
    ret = ''.join(pieces)
    print(ret)
    return ret
| 164 | |
| 165 | |
| 166 | if __name__ == '__main__': |
no test coverage detected