MCPcopy
hub / github.com/karpathy/neuraltalk / gradCheck

Method gradCheck

imagernn/solver.py:89–142  ·  view source on GitHub ↗

Perform a gradient check. Since gradient checking can be tricky (especially with ReLUs involved), this function prints its results to the console for visual inspection.

(self, batch, model, cost_function, **kwargs)

Source from the content-addressed store, hash-verified

87 return out
88
def gradCheck(self, batch, model, cost_function, **kwargs):
    """
    Numerically spot-check the analytic gradients returned by cost_function.

    For every parameter in ``model``, a handful of randomly chosen entries are
    perturbed by +/- delta, the cost is re-evaluated, and the centered
    finite-difference estimate is compared with the analytic gradient.
    Gradient checks can be tricky (especially with ReLUs involved), so the
    results are printed to the console for visual inspection, not returned.

    Args:
        batch: passed through unchanged to ``cost_function``.
        model: mapping from parameter name to an array exposing ``shape``,
            ``size`` and ``flat``. Entries are perturbed in place and
            restored after each check.
        cost_function: callable ``(batch, model)`` whose result provides
            ``['grad'][name]`` (array shaped like the parameter) and
            ``['cost']['total_cost']`` (scalar cost).
        **kwargs: optional settings:
            num_checks (default 10): entries sampled per parameter.
            delta (default 1e-5): finite-difference step.
            rel_error_thr_warning (default 1e-2): relative error above which
                a check is labelled 'WARNING'.
            rel_error_thr_error (default 1): relative error above which a
                check is labelled '!!!!! NOTOK'.

    Raises:
        AssertionError: if a gradient's shape differs from its parameter's.
    """

    num_checks = kwargs.get('num_checks', 10)
    delta = kwargs.get('delta', 1e-5)
    rel_error_thr_warning = kwargs.get('rel_error_thr_warning', 1e-2)
    rel_error_thr_error = kwargs.get('rel_error_thr_error', 1)

    # analytic gradients, evaluated once at the unperturbed parameters
    cg = cost_function(batch, model)

    print('running gradient check...')
    for p in model.keys():
        print('checking gradient on parameter %s of shape %s...' % (p, repr(model[p].shape)))
        mat = model[p]

        s0 = cg['grad'][p].shape
        s1 = mat.shape
        # raised explicitly (not via `assert`) so the check survives `python -O`
        if s0 != s1:
            raise AssertionError('Error dims dont match: %s and %s.' % (repr(s0), repr(s1)))

        # an empty parameter has no entry to sample; indexing it would fail
        if mat.size == 0:
            print('skipping parameter %s: it has no entries to check' % (p,))
            continue

        for _ in range(num_checks):
            ri = randi(mat.size)

            # evaluate cost at [x + delta] and [x - delta]
            old_val = mat.flat[ri]
            try:
                mat.flat[ri] = old_val + delta
                cg0 = cost_function(batch, model)
                mat.flat[ri] = old_val - delta
                cg1 = cost_function(batch, model)
            finally:
                # always reset the old value, even if cost_function raised,
                # so a failed check never leaves the model perturbed
                mat.flat[ri] = old_val

            # fetch both numerical and analytic gradient
            grad_analytic = cg['grad'][p].flat[ri]
            grad_numerical = (cg0['cost']['total_cost'] - cg1['cost']['total_cost']) / (2 * delta)

            # compare them
            if grad_numerical == 0 and grad_analytic == 0:
                rel_error = 0  # both are zero, OK.
                status = 'OK'
            elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
                rel_error = 0  # not enough precision to check this
                status = 'VAL SMALL WARNING'
            else:
                denom = abs(grad_numerical + grad_analytic)
                if denom == 0:
                    # equal magnitude, opposite sign: the gradients disagree
                    # completely; report infinite error instead of dividing by 0
                    rel_error = float('inf')
                else:
                    rel_error = abs(grad_analytic - grad_numerical) / denom
                status = 'OK'
                if rel_error > rel_error_thr_warning:
                    status = 'WARNING'
                if rel_error > rel_error_thr_error:
                    status = '!!!!! NOTOK'

            # print stats
            print('%s checking param %s index %8d (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f'
                  % (status, p, ri, old_val, grad_analytic, grad_numerical, rel_error))
143
144
145

Callers 1

mainFunction · 0.95

Calls 1

randiFunction · 0.90

Tested by

no test coverage detected