(self, X, learning_rate=1e-5, mu=0.99, epochs=10, show_fig=True, activation=T.nnet.relu, RecurrentUnit=GRU, normalize=True)
| 29 | self.V = V |
| 30 | |
| 31 | def fit(self, X, learning_rate=1e-5, mu=0.99, epochs=10, show_fig=True, activation=T.nnet.relu, RecurrentUnit=GRU, normalize=True): |
| 32 | D = self.D |
| 33 | V = self.V |
| 34 | N = len(X) |
| 35 | |
| 36 | We = init_weight(V, D) |
| 37 | self.hidden_layers = [] |
| 38 | Mi = D |
| 39 | for Mo in self.hidden_layer_sizes: |
| 40 | ru = RecurrentUnit(Mi, Mo, activation) |
| 41 | self.hidden_layers.append(ru) |
| 42 | Mi = Mo |
| 43 | |
| 44 | Wo = init_weight(Mi, V) |
| 45 | bo = np.zeros(V) |
| 46 | |
| 47 | self.We = theano.shared(We) |
| 48 | self.Wo = theano.shared(Wo) |
| 49 | self.bo = theano.shared(bo) |
| 50 | self.params = [self.Wo, self.bo] |
| 51 | for ru in self.hidden_layers: |
| 52 | self.params += ru.params |
| 53 | |
| 54 | thX = T.ivector('X') |
| 55 | thY = T.ivector('Y') |
| 56 | |
| 57 | Z = self.We[thX] |
| 58 | for ru in self.hidden_layers: |
| 59 | Z = ru.output(Z) |
| 60 | py_x = T.nnet.softmax(Z.dot(self.Wo) + self.bo) |
| 61 | |
| 62 | prediction = T.argmax(py_x, axis=1) |
| 63 | # let's return py_x too so we can draw a sample instead |
| 64 | self.predict_op = theano.function( |
| 65 | inputs=[thX], |
| 66 | outputs=[py_x, prediction], |
| 67 | allow_input_downcast=True, |
| 68 | ) |
| 69 | |
| 70 | cost = -T.mean(T.log(py_x[T.arange(thY.shape[0]), thY])) |
| 71 | grads = T.grad(cost, self.params) |
| 72 | dparams = [theano.shared(p.get_value()*0) for p in self.params] |
| 73 | |
| 74 | dWe = theano.shared(self.We.get_value()*0) |
| 75 | gWe = T.grad(cost, self.We) |
| 76 | dWe_update = mu*dWe - learning_rate*gWe |
| 77 | We_update = self.We + dWe_update |
| 78 | if normalize: |
| 79 | We_update /= We_update.norm(2) |
| 80 | |
| 81 | updates = [ |
| 82 | (p, p + mu*dp - learning_rate*g) for p, dp, g in zip(self.params, dparams, grads) |
| 83 | ] + [ |
| 84 | (dp, mu*dp - learning_rate*g) for dp, g in zip(dparams, grads) |
| 85 | ] + [ |
| 86 | (self.We, We_update), (dWe, dWe_update) |
| 87 | ] |
| 88 |
no test coverage detected