()
| 216 | |
| 217 | |
def main():
    """Fit a TNN on a small slice of the PTB data and print its accuracy.

    Loads the train split, test split and ``word2idx`` mapping via
    ``get_ptb_data``. Truncates both splits to their first 100 items and
    builds a ``TNN`` sized by the vocabulary. Fits the model on the
    truncated training split, then prints the "train" and "test" accuracy
    reported by ``TNN.score``.
    """
    train_split, test_split, word_index = get_ptb_data()

    # Only the first 100 items of each split are used -- presumably to keep
    # this demo run short; confirm before relying on these numbers.
    subset = 100
    train_split, test_split = train_split[:subset], test_split[:subset]

    V = len(word_index)  # number of entries in the word -> index mapping
    D = 80  # presumably the embedding / hidden size -- TODO confirm against TNN
    K = 5   # presumably the number of output classes -- TODO confirm against TNN

    model = TNN(V, D, K, tf.nn.relu)
    model.fit(train_split)

    # NOTE(review): the "train accuracy" line passes None, not train_split.
    # Presumably TNN.score(None) falls back to results cached during fit().
    # Verify against TNN.score; if it does not, this should be
    # model.score(train_split).
    for caption, dataset in (("train accuracy:", None),
                             ("test accuracy:", test_split)):
        print(caption, model.score(dataset))
| 232 | |
| 233 | |
| 234 | if __name__ == '__main__': |
no test coverage detected