(self, X, Y, learning_rate=0.1, mu=0.99, reg=1.0, activation=T.tanh, epochs=100, show_fig=False)
| 20 | self.M = M # hidden layer size |
| 21 | |
| 22 | def fit(self, X, Y, learning_rate=0.1, mu=0.99, reg=1.0, activation=T.tanh, epochs=100, show_fig=False): |
| 23 | D = X[0].shape[1] # X is of size N x T(n) x D |
| 24 | K = len(set(Y.flatten())) |
| 25 | N = len(Y) |
| 26 | M = self.M |
| 27 | self.f = activation |
| 28 | |
| 29 | # initial weights |
| 30 | Wx = init_weight(D, M) |
| 31 | Wh = init_weight(M, M) |
| 32 | bh = np.zeros(M) |
| 33 | h0 = np.zeros(M) |
| 34 | Wo = init_weight(M, K) |
| 35 | bo = np.zeros(K) |
| 36 | |
| 37 | # make them theano shared |
| 38 | self.Wx = theano.shared(Wx) |
| 39 | self.Wh = theano.shared(Wh) |
| 40 | self.bh = theano.shared(bh) |
| 41 | self.h0 = theano.shared(h0) |
| 42 | self.Wo = theano.shared(Wo) |
| 43 | self.bo = theano.shared(bo) |
| 44 | self.params = [self.Wx, self.Wh, self.bh, self.h0, self.Wo, self.bo] |
| 45 | |
| 46 | thX = T.fmatrix('X') |
| 47 | thY = T.ivector('Y') |
| 48 | |
| 49 | def recurrence(x_t, h_t1): |
| 50 | # returns h(t), y(t) |
| 51 | h_t = self.f(x_t.dot(self.Wx) + h_t1.dot(self.Wh) + self.bh) |
| 52 | y_t = T.nnet.softmax(h_t.dot(self.Wo) + self.bo) |
| 53 | return h_t, y_t |
| 54 | |
| 55 | [h, y], _ = theano.scan( |
| 56 | fn=recurrence, |
| 57 | outputs_info=[self.h0, None], |
| 58 | sequences=thX, |
| 59 | n_steps=thX.shape[0], |
| 60 | ) |
| 61 | |
| 62 | py_x = y[:, 0, :] |
| 63 | prediction = T.argmax(py_x, axis=1) |
| 64 | |
| 65 | cost = -T.mean(T.log(py_x[T.arange(thY.shape[0]), thY])) |
| 66 | grads = T.grad(cost, self.params) |
| 67 | dparams = [theano.shared(p.get_value()*0) for p in self.params] |
| 68 | |
| 69 | updates = [ |
| 70 | (p, p + mu*dp - learning_rate*g) for p, dp, g in zip(self.params, dparams, grads) |
| 71 | ] + [ |
| 72 | (dp, mu*dp - learning_rate*g) for dp, g in zip(dparams, grads) |
| 73 | ] |
| 74 | |
| 75 | self.predict_op = theano.function(inputs=[thX], outputs=prediction) |
| 76 | self.train_op = theano.function( |
| 77 | inputs=[thX, thY], |
| 78 | outputs=[cost, prediction, y], |
| 79 | updates=updates |
no test coverage detected