| 249 | |
| 250 | |
| 251 | def main(param=None): |
| 252 | if not param: |
| 253 | param = { |
| 254 | 'fold': 3, |
| 255 | # 5 folds 0,1,2,3,4 |
| 256 | 'data': 'atis', |
| 257 | 'lr': 0.0970806646812754, |
| 258 | 'verbose': 1, |
| 259 | 'decay': True, |
| 260 | # decay on the learning rate if improvement stops |
| 261 | 'win': 7, |
| 262 | # number of words in the context window |
| 263 | 'nhidden': 200, |
| 264 | # number of hidden units |
| 265 | 'seed': 345, |
| 266 | 'emb_dimension': 50, |
| 267 | # dimension of word embedding |
| 268 | 'nepochs': 60, |
| 269 | # 60 is recommended |
| 270 | 'savemodel': False} |
| 271 | print(param) |
| 272 | |
| 273 | folder_name = os.path.basename(__file__).split('.')[0] |
| 274 | folder = os.path.join(os.path.dirname(__file__), folder_name) |
| 275 | if not os.path.exists(folder): |
| 276 | os.mkdir(folder) |
| 277 | script_path = os.path.dirname(__file__) |
| 278 | |
| 279 | # load the dataset |
| 280 | train_set, valid_set, test_set, dic = atisfold(param['fold']) |
| 281 | |
| 282 | idx2label = dict((k, v) for v, k in dic['labels2idx'].items()) |
| 283 | idx2word = dict((k, v) for v, k in dic['words2idx'].items()) |
| 284 | |
| 285 | train_lex, train_ne, train_y = train_set |
| 286 | valid_lex, valid_ne, valid_y = valid_set |
| 287 | test_lex, test_ne, test_y = test_set |
| 288 | |
| 289 | vocsize = len(dic['words2idx']) |
| 290 | nclasses = len(dic['labels2idx']) |
| 291 | nsentences = len(train_lex) |
| 292 | |
| 293 | groundtruth_valid = [map(lambda x: idx2label[x], y) for y in valid_y] |
| 294 | words_valid = [map(lambda x: idx2word[x], w) for w in valid_lex] |
| 295 | groundtruth_test = [map(lambda x: idx2label[x], y) for y in test_y] |
| 296 | words_test = [map(lambda x: idx2word[x], w) for w in test_lex] |
| 297 | |
| 298 | # instanciate the model |
| 299 | numpy.random.seed(param['seed']) |
| 300 | random.seed(param['seed']) |
| 301 | |
| 302 | rnn = RNNSLU(nh=param['nhidden'], |
| 303 | nc=nclasses, |
| 304 | ne=vocsize, |
| 305 | de=param['emb_dimension'], |
| 306 | cs=param['win']) |
| 307 | |
| 308 | # train with early stopping on validation set |