Solver responsibilities: supports different optimization methods, parameter updates, and weight decay; it can also perform a gradient check.
| 3 | from imagernn.utils import randi |
| 4 | |
| 5 | class Solver: |
| 6 | """ |
| 7 | solver worries about: |
| 8 | - different optimization methods, updates, weight decays |
| 9 | - it can also perform gradient check |
| 10 | """ |
def __init__(self):
    """Set up a solver with empty per-parameter optimizer state.

    step() lazily fills these dicts with zero arrays shaped like the
    corresponding model parameter, keyed by the parameter's model key.
    ``step_cache_`` is filled whenever the solver is anything other than
    plain 'vanilla' SGD with zero momentum. ``step_cache2_`` is also
    filled when the solver is 'adadelta'.
    """
    # Two independent dicts: one accumulator per parameter, plus the
    # second accumulator that adadelta needs.
    self.step_cache_, self.step_cache2_ = {}, {}
| 14 | |
| 15 | def step(self, batch, model, cost_function, **kwargs): |
| 16 | """ |
| 17 | perform a single batch update. Takes as input: |
| 18 | - batch of data (X) |
| 19 | - model (W) |
| 20 | - cost function which takes batch, model |
| 21 | """ |
| 22 | |
| 23 | learning_rate = kwargs.get('learning_rate', 0.0) |
| 24 | update = kwargs.get('update', model.keys()) |
| 25 | grad_clip = kwargs.get('grad_clip', -1) |
| 26 | solver = kwargs.get('solver', 'vanilla') |
| 27 | momentum = kwargs.get('momentum', 0) |
| 28 | smooth_eps = kwargs.get('smooth_eps', 1e-8) |
| 29 | decay_rate = kwargs.get('decay_rate', 0.999) |
| 30 | |
| 31 | if not (solver == 'vanilla' and momentum == 0): |
| 32 | # lazily make sure we initialize step cache if needed |
| 33 | for u in update: |
| 34 | if not u in self.step_cache_: |
| 35 | self.step_cache_[u] = np.zeros(model[u].shape) |
| 36 | if solver == 'adadelta': |
| 37 | self.step_cache2_[u] = np.zeros(model[u].shape) # adadelta needs one more cache |
| 38 | |
| 39 | # compute cost and gradient |
| 40 | cg = cost_function(batch, model) |
| 41 | cost = cg['cost'] |
| 42 | grads = cg['grad'] |
| 43 | stats = cg['stats'] |
| 44 | |
| 45 | # clip gradients if needed, simplest possible version |
| 46 | # todo later: maybe implement the gradient direction conserving version |
| 47 | if grad_clip > 0: |
| 48 | for p in update: |
| 49 | if p in grads: |
| 50 | grads[p] = np.minimum(grads[p], grad_clip) |
| 51 | grads[p] = np.maximum(grads[p], -grad_clip) |
| 52 | |
| 53 | # perform parameter update |
| 54 | for p in update: |
| 55 | if p in grads: |
| 56 | |
| 57 | if solver == 'vanilla': # vanilla sgd, optional with momentum |
| 58 | if momentum > 0: |
| 59 | dx = momentum * self.step_cache_[p] - learning_rate * grads[p] |
| 60 | self.step_cache_[p] = dx |
| 61 | else: |
| 62 | dx = - learning_rate * grads[p] |