MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / sample

Function sample

examples/Char-RNN/char-rnn.py:128–163  ·  view source on GitHub ↗

:param path: path to the model; :param start: a `str`, the starting characters; :param length: an `int`, the length of text to generate

(path, start, length)

Source from the content-addressed store, hash-verified

126
127
def sample(path, start, length):
    """
    Generate text from a trained char-RNN model and print it.

    The characters of ``start`` are fed through the network one at a time
    (carrying the recurrent state forward), then ``length`` further characters
    are sampled from the predicted distribution and appended.

    :param path: path to the model
    :param start: a `str`. the starting characters. Must be non-empty, and every
        character must be present in the corpus vocabulary (``ds.char2idx``).
    :param length: a `int`. the length of text to generate
    :raises ValueError: if ``start`` is empty.
    """
    if not start:
        raise ValueError("`start` must contain at least one character")

    # initialize vocabulary and sequence length
    param.seq_len = 1
    ds = CharRNNData(param.corpus, 100000)

    pred = OfflinePredictor(PredictConfig(
        model=Model(),
        session_init=SmartInit(path),
        input_names=['input', 'c0', 'h0', 'c1', 'h1'],
        output_names=['prob', 'last_state']))

    def unpack(last_state):
        # Reorder the 'last_state' output into the (c0, h0, c1, h1) argument
        # order of input_names: state[i, 0] feeds c_i and state[i, 1] feeds h_i.
        return (last_state[0, 0], last_state[0, 1],
                last_state[1, 0], last_state[1, 1])

    def pick(prob):
        # Inverse-CDF sampling from a (possibly unnormalized) distribution.
        cdf = np.cumsum(prob)
        # Scale by cdf[-1] rather than a separate np.sum(prob): the draw is then
        # always <= cdf[-1], so searchsorted can never return len(prob).
        # np.random.rand() (scalar) avoids int() on an ndim>0 array.
        return int(np.searchsorted(cdf, np.random.rand() * cdf[-1]))

    # Zero-filled initial c0, h0, c1, h1.
    zero = np.zeros((1, param.rnn_size))
    state = (zero, zero, zero, zero)

    # Feed the starting sentence (all but its last char), carrying the state
    # forward so the whole prefix conditions the generated text. When
    # len(start) == 1 this loop is skipped and the zero state is used.
    for c in start[:-1]:
        x = np.array([[ds.char2idx[c]]], dtype='int32')
        _, last_state = pred(x, *state)
        state = unpack(last_state)

    # Generate more, starting from the last character of `start`.
    pieces = [start]
    c = start[-1]
    for _ in range(length):
        x = np.array([[ds.char2idx[c]]], dtype='int32')
        prob, last_state = pred(x, *state)
        state = unpack(last_state)
        c = ds.chars[pick(prob[0])]
        pieces.append(c)
    print(''.join(pieces))
164
165
166if __name__ == '__main__':

Callers 1

char-rnn.py · File · 0.70

Calls 6

CharRNNData · Class · 0.85
OfflinePredictor · Class · 0.85
PredictConfig · Class · 0.85
SmartInit · Function · 0.85
pick · Function · 0.85
Model · Class · 0.70

Tested by

no test coverage detected