MCPcopy
hub / github.com/karpathy/neuraltalk / main

Function main

driver.py:117–262  ·  view source on GitHub ↗
(params)

Source from the content-addressed store, hash-verified

115 return out
116
117def main(params):
118 batch_size = params['batch_size']
119 dataset = params['dataset']
120 word_count_threshold = params['word_count_threshold']
121 do_grad_check = params['do_grad_check']
122 max_epochs = params['max_epochs']
123 host = socket.gethostname() # get computer hostname
124
125 # fetch the data provider
126 dp = getDataProvider(dataset)
127
128 misc = {} # stores various misc items that need to be passed around the framework
129
130 # go over all training sentences and find the vocabulary we want to use, i.e. the words that occur
131 # at least word_count_threshold number of times
132 misc['wordtoix'], misc['ixtoword'], bias_init_vector = preProBuildWordVocab(dp.iterSentences('train'), word_count_threshold)
133
134 # delegate the initialization of the model to the Generator class
135 BatchGenerator = decodeGenerator(params)
136 init_struct = BatchGenerator.init(params, misc)
137 model, misc['update'], misc['regularize'] = (init_struct['model'], init_struct['update'], init_struct['regularize'])
138
139 # force overwrite here. This is a bit of a hack, not happy about it
140 model['bd'] = bias_init_vector.reshape(1, bias_init_vector.size)
141
142 print 'model init done.'
143 print 'model has keys: ' + ', '.join(model.keys())
144 print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['update'])
145 print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['regularize'])
146 print 'number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']), )
147
148 if params.get('init_model_from', ''):
149 # load checkpoint
150 checkpoint = pickle.load(open(params['init_model_from'], 'rb'))
151 model = checkpoint['model'] # overwrite the model
152
153 # initialize the Solver and the cost function
154 solver = Solver()
155 def costfun(batch, model):
156 # wrap the cost function to abstract some things away from the Solver
157 return RNNGenCost(batch, model, params, misc)
158
159 # calculate how many iterations we need
160 num_sentences_total = dp.getSplitSize('train', ofwhat = 'sentences')
161 num_iters_one_epoch = num_sentences_total / batch_size
162 max_iters = max_epochs * num_iters_one_epoch
163 eval_period_in_epochs = params['eval_period']
164 eval_period_in_iters = max(1, int(num_iters_one_epoch * eval_period_in_epochs))
165 abort = False
166 top_val_ppl2 = -1
167 smooth_train_ppl2 = len(misc['ixtoword']) # initially size of dictionary of confusion
168 val_ppl2 = len(misc['ixtoword'])
169 last_status_write_time = 0 # for writing worker job status reports
170 json_worker_status = {}
171 json_worker_status['params'] = params
172 json_worker_status['history'] = []
173 for it in xrange(max_iters):
174 if abort: break

Callers 1

driver.py (File) · 0.70

Calls 11

step (Method) · 0.95
gradCheck (Method) · 0.95
getDataProvider (Function) · 0.90
decodeGenerator (Function) · 0.90
Solver (Class) · 0.90
eval_split (Function) · 0.90
preProBuildWordVocab (Function) · 0.85
iterSentences (Method) · 0.80
getSplitSize (Method) · 0.80
init (Method) · 0.45

Tested by

no test coverage detected