MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / fit

Method fit

rnn_class/srn_parity.py:22–101  ·  view source on GitHub ↗
(self, X, Y, learning_rate=0.1, mu=0.99, reg=1.0, activation=T.tanh, epochs=100, show_fig=False)

Source from the content-addressed store, hash-verified

20 self.M = M # hidden layer size
21
22 def fit(self, X, Y, learning_rate=0.1, mu=0.99, reg=1.0, activation=T.tanh, epochs=100, show_fig=False):
23 D = X[0].shape[1] # X is of size N x T(n) x D
24 K = len(set(Y.flatten()))
25 N = len(Y)
26 M = self.M
27 self.f = activation
28
29 # initial weights
30 Wx = init_weight(D, M)
31 Wh = init_weight(M, M)
32 bh = np.zeros(M)
33 h0 = np.zeros(M)
34 Wo = init_weight(M, K)
35 bo = np.zeros(K)
36
37 # make them theano shared
38 self.Wx = theano.shared(Wx)
39 self.Wh = theano.shared(Wh)
40 self.bh = theano.shared(bh)
41 self.h0 = theano.shared(h0)
42 self.Wo = theano.shared(Wo)
43 self.bo = theano.shared(bo)
44 self.params = [self.Wx, self.Wh, self.bh, self.h0, self.Wo, self.bo]
45
46 thX = T.fmatrix('X')
47 thY = T.ivector('Y')
48
49 def recurrence(x_t, h_t1):
50 # returns h(t), y(t)
51 h_t = self.f(x_t.dot(self.Wx) + h_t1.dot(self.Wh) + self.bh)
52 y_t = T.nnet.softmax(h_t.dot(self.Wo) + self.bo)
53 return h_t, y_t
54
55 [h, y], _ = theano.scan(
56 fn=recurrence,
57 outputs_info=[self.h0, None],
58 sequences=thX,
59 n_steps=thX.shape[0],
60 )
61
62 py_x = y[:, 0, :]
63 prediction = T.argmax(py_x, axis=1)
64
65 cost = -T.mean(T.log(py_x[T.arange(thY.shape[0]), thY]))
66 grads = T.grad(cost, self.params)
67 dparams = [theano.shared(p.get_value()*0) for p in self.params]
68
69 updates = [
70 (p, p + mu*dp - learning_rate*g) for p, dp, g in zip(self.params, dparams, grads)
71 ] + [
72 (dp, mu*dp - learning_rate*g) for dp, g in zip(dparams, grads)
73 ]
74
75 self.predict_op = theano.function(inputs=[thX], outputs=prediction)
76 self.train_op = theano.function(
77 inputs=[thX, thY],
78 outputs=[cost, prediction, y],
79 updates=updates

Callers 1

parity (Function) · 0.95

Calls 2

init_weight (Function) · 0.90
grad (Method) · 0.45

Tested by

no test coverage detected