hub / github.com/tensorlayer/TensorLayer / main

Function main

examples/text_ptb/tutorial_ptb_lstm.py:156–314

The core of the model consists of an LSTM cell that processes one word at a time and computes probabilities of the possible continuations of the sentence. The memory state of the network is initialized with a vector of zeros and gets updated after reading each word. Also, for computational reasons, we will process data in mini-batches of size batch_size.
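The docstring of main says the data is processed in mini-batches of size batch_size, and the configuration sets a window length num_steps. The batching code is not part of this excerpt, so the following is only a minimal sketch of one common layout for this kind of language-model training: the flat list of word IDs is cut into batch_size parallel streams, and each step yields a window of num_steps inputs together with the same window shifted by one word as targets. The name iterate_ptb_batches is hypothetical and is not taken from the repository.

```python
def iterate_ptb_batches(word_ids, batch_size, num_steps):
    """Yield (inputs, targets) pairs, each a list of batch_size rows of num_steps IDs.

    Hypothetical helper for illustration; not the repository's implementation.
    """
    # Split the corpus into batch_size contiguous streams of equal length.
    batch_len = len(word_ids) // batch_size
    rows = [word_ids[r * batch_len:(r + 1) * batch_len] for r in range(batch_size)]

    # One position per stream is reserved because targets are shifted by one word.
    epoch_size = (batch_len - 1) // num_steps
    if epoch_size == 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

    for i in range(epoch_size):
        start = i * num_steps
        inputs = [row[start:start + num_steps] for row in rows]
        targets = [row[start + 1:start + num_steps + 1] for row in rows]
        yield inputs, targets


# Toy corpus of 20 "word IDs" so the windows are easy to read.
toy_corpus = list(range(20))
batches = list(iterate_ptb_batches(toy_corpus, batch_size=2, num_steps=3))
for inputs, targets in batches:
    print(inputs, targets)
```

Because consecutive windows continue where the previous ones stopped in each stream, a recurrent state carried from one window to the next keeps seeing the text in order.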

()

Source from the content-addressed store, hash-verified

def main():
    """
    The core of the model consists of an LSTM cell that processes one word at
    a time and computes probabilities of the possible continuations of the
    sentence. The memory state of the network is initialized with a vector
    of zeros and gets updated after reading each word. Also, for computational
    reasons, we will process data in mini-batches of size batch_size.

    """
    param = process_args(sys.argv[1:])

    # Select the hyperparameter preset by model size.
    if param.model == "small":
        init_scale = 0.1
        learning_rate = 1e-3
        max_grad_norm = 5
        num_steps = 20
        hidden_size = 200
        max_epoch = 4
        max_max_epoch = 13
        keep_prob = 1.0
        lr_decay = 0.5
        batch_size = 20
        vocab_size = 10000
    elif param.model == "medium":
        init_scale = 0.05
        learning_rate = 1e-3
        max_grad_norm = 5
        # num_layers = 2
        num_steps = 35
        hidden_size = 650
        max_epoch = 6
        max_max_epoch = 39
        keep_prob = 0.5
        lr_decay = 0.8
        batch_size = 20
        vocab_size = 10000
    elif param.model == "large":
        init_scale = 0.04
        learning_rate = 1e-3
        max_grad_norm = 10
        # num_layers = 2
        num_steps = 35
        hidden_size = 1500
        max_epoch = 14
        max_max_epoch = 55
        keep_prob = 0.35
        lr_decay = 1 / 1.15
        batch_size = 20
        vocab_size = 10000
    else:
        # Interpolate the model name into the message instead of passing it
        # as a second exception argument.
        raise ValueError("Invalid model: %s" % param.model)

    # Load PTB dataset (the returned vocab_size replaces the preset value)
    train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset()
    # train_data = train_data[0:int(100000/5)] # for fast testing
    print('len(train_data) {}'.format(len(train_data)))  # 929589 a list of int
    print('len(valid_data) {}'.format(len(valid_data)))  # 73760 a list of int
    print('len(test_data) {}'.format(len(test_data)))  # 82430 a list of int
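The three branches above differ only in their values, so the presets could also be held in a lookup table. The following is a sketch of that alternative, not the tutorial's code: PTBConfig, CONFIGS and get_config are hypothetical names introduced here, and the numbers are copied from the branches shown above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PTBConfig:
    """Hypothetical container for the presets selected in main()."""
    init_scale: float
    learning_rate: float
    max_grad_norm: float
    num_steps: int
    hidden_size: int
    max_epoch: int
    max_max_epoch: int
    keep_prob: float
    lr_decay: float
    batch_size: int = 20      # identical in all three presets
    vocab_size: int = 10000   # identical in all three presets


CONFIGS = {
    "small": PTBConfig(init_scale=0.1, learning_rate=1e-3, max_grad_norm=5,
                       num_steps=20, hidden_size=200, max_epoch=4,
                       max_max_epoch=13, keep_prob=1.0, lr_decay=0.5),
    "medium": PTBConfig(init_scale=0.05, learning_rate=1e-3, max_grad_norm=5,
                        num_steps=35, hidden_size=650, max_epoch=6,
                        max_max_epoch=39, keep_prob=0.5, lr_decay=0.8),
    "large": PTBConfig(init_scale=0.04, learning_rate=1e-3, max_grad_norm=10,
                       num_steps=35, hidden_size=1500, max_epoch=14,
                       max_max_epoch=55, keep_prob=0.35, lr_decay=1 / 1.15),
}


def get_config(name):
    """Return the preset for `name`, or raise ValueError for an unknown model."""
    try:
        return CONFIGS[name]
    except KeyError:
        raise ValueError("Invalid model: %s" % name) from None


config = get_config("medium")
print(config.hidden_size, config.num_steps, config.keep_prob)
```

A frozen dataclass keeps each preset immutable and lets the rest of the function read config.hidden_size and similar fields instead of eleven local variables.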

Callers 1

Calls 5

process_args · Function · 0.85
PTB_Net · Class · 0.85
gradient · Method · 0.80
eval · Method · 0.80
train · Method · 0.45

Tested by

no test coverage detected
