The core of the model consists of an LSTM cell that processes one word at a time and computes probabilities of the possible continuations of the sentence. The memory state of the network is initialized with a vector of zeros and gets updated after reading each word. Also, for computational reasons, we will process data in mini-batches of size batch_size.
()
| 154 | |
| 155 | |
| 156 | def main(): |
| 157 | """ |
| 158 | The core of the model consists of an LSTM cell that processes one word at |
| 159 | a time and computes probabilities of the possible continuations of the |
| 160 | sentence. The memory state of the network is initialized with a vector |
| 161 | of zeros and gets updated after reading each word. Also, for computational |
| 162 | reasons, we will process data in mini-batches of size batch_size. |
| 163 | |
| 164 | """ |
| 165 | param = process_args(sys.argv[1:]) |
| 166 | |
| 167 | if param.model == "small": |
| 168 | init_scale = 0.1 |
| 169 | learning_rate = 1e-3 |
| 170 | max_grad_norm = 5 |
| 171 | num_steps = 20 |
| 172 | hidden_size = 200 |
| 173 | max_epoch = 4 |
| 174 | max_max_epoch = 13 |
| 175 | keep_prob = 1.0 |
| 176 | lr_decay = 0.5 |
| 177 | batch_size = 20 |
| 178 | vocab_size = 10000 |
| 179 | elif param.model == "medium": |
| 180 | init_scale = 0.05 |
| 181 | learning_rate = 1e-3 |
| 182 | max_grad_norm = 5 |
| 183 | # num_layers = 2 |
| 184 | num_steps = 35 |
| 185 | hidden_size = 650 |
| 186 | max_epoch = 6 |
| 187 | max_max_epoch = 39 |
| 188 | keep_prob = 0.5 |
| 189 | lr_decay = 0.8 |
| 190 | batch_size = 20 |
| 191 | vocab_size = 10000 |
| 192 | elif param.model == "large": |
| 193 | init_scale = 0.04 |
| 194 | learning_rate = 1e-3 |
| 195 | max_grad_norm = 10 |
| 196 | # num_layers = 2 |
| 197 | num_steps = 35 |
| 198 | hidden_size = 1500 |
| 199 | max_epoch = 14 |
| 200 | max_max_epoch = 55 |
| 201 | keep_prob = 0.35 |
| 202 | lr_decay = 1 / 1.15 |
| 203 | batch_size = 20 |
| 204 | vocab_size = 10000 |
| 205 | else: |
| 206 | raise ValueError("Invalid model: %s", param.model) |
| 207 | |
| 208 | # Load PTB dataset |
| 209 | train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset() |
| 210 | # train_data = train_data[0:int(100000/5)] # for fast testing |
| 211 | print('len(train_data) {}'.format(len(train_data))) # 929589 a list of int |
| 212 | print('len(valid_data) {}'.format(len(valid_data))) # 73760 a list of int |
| 213 | print('len(test_data) {}'.format(len(test_data))) # 82430 a list of int |
no test coverage detected
searching dependent graphs…