(params)
| 115 | return out |
| 116 | |
| 117 | def main(params): |
| 118 | batch_size = params['batch_size'] |
| 119 | dataset = params['dataset'] |
| 120 | word_count_threshold = params['word_count_threshold'] |
| 121 | do_grad_check = params['do_grad_check'] |
| 122 | max_epochs = params['max_epochs'] |
| 123 | host = socket.gethostname() # get computer hostname |
| 124 | |
| 125 | # fetch the data provider |
| 126 | dp = getDataProvider(dataset) |
| 127 | |
| 128 | misc = {} # stores various misc items that need to be passed around the framework |
| 129 | |
| 130 | # go over all training sentences and find the vocabulary we want to use, i.e. the words that occur |
| 131 | # at least word_count_threshold number of times |
| 132 | misc['wordtoix'], misc['ixtoword'], bias_init_vector = preProBuildWordVocab(dp.iterSentences('train'), word_count_threshold) |
| 133 | |
| 134 | # delegate the initialization of the model to the Generator class |
| 135 | BatchGenerator = decodeGenerator(params) |
| 136 | init_struct = BatchGenerator.init(params, misc) |
| 137 | model, misc['update'], misc['regularize'] = (init_struct['model'], init_struct['update'], init_struct['regularize']) |
| 138 | |
| 139 | # force overwrite here. This is a bit of a hack, not happy about it |
| 140 | model['bd'] = bias_init_vector.reshape(1, bias_init_vector.size) |
| 141 | |
| 142 | print 'model init done.' |
| 143 | print 'model has keys: ' + ', '.join(model.keys()) |
| 144 | print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['update']) |
| 145 | print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['regularize']) |
| 146 | print 'number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']), ) |
| 147 | |
| 148 | if params.get('init_model_from', ''): |
| 149 | # load checkpoint |
| 150 | checkpoint = pickle.load(open(params['init_model_from'], 'rb')) |
| 151 | model = checkpoint['model'] # overwrite the model |
| 152 | |
| 153 | # initialize the Solver and the cost function |
| 154 | solver = Solver() |
| 155 | def costfun(batch, model): |
| 156 | # wrap the cost function to abstract some things away from the Solver |
| 157 | return RNNGenCost(batch, model, params, misc) |
| 158 | |
| 159 | # calculate how many iterations we need |
| 160 | num_sentences_total = dp.getSplitSize('train', ofwhat = 'sentences') |
| 161 | num_iters_one_epoch = num_sentences_total / batch_size |
| 162 | max_iters = max_epochs * num_iters_one_epoch |
| 163 | eval_period_in_epochs = params['eval_period'] |
| 164 | eval_period_in_iters = max(1, int(num_iters_one_epoch * eval_period_in_epochs)) |
| 165 | abort = False |
| 166 | top_val_ppl2 = -1 |
| 167 | smooth_train_ppl2 = len(misc['ixtoword']) # initially size of dictionary of confusion |
| 168 | val_ppl2 = len(misc['ixtoword']) |
| 169 | last_status_write_time = 0 # for writing worker job status reports |
| 170 | json_worker_status = {} |
| 171 | json_worker_status['params'] = params |
| 172 | json_worker_status['history'] = [] |
| 173 | for it in xrange(max_iters): |
| 174 | if abort: break |
no test coverage detected