(self, X, Y, activation=T.nnet.relu, learning_rate=1e-3, mu=0.0, reg=0, epochs=100, batch_sz=None, print_period=100, show_fig=True)
| 40 | self.hidden_layer_sizes = hidden_layer_sizes |
| 41 | |
| 42 | def fit(self, X, Y, activation=T.nnet.relu, learning_rate=1e-3, mu=0.0, reg=0, epochs=100, batch_sz=None, print_period=100, show_fig=True): |
| 43 | X = X.astype(np.float32) |
| 44 | Y = Y.astype(np.int32) |
| 45 | |
| 46 | # initialize hidden layers |
| 47 | N, D = X.shape |
| 48 | self.layers = [] |
| 49 | M1 = D |
| 50 | for M2 in self.hidden_layer_sizes: |
| 51 | h = HiddenLayer(M1, M2, activation) |
| 52 | self.layers.append(h) |
| 53 | M1 = M2 |
| 54 | |
| 55 | # final layer |
| 56 | K = len(set(Y)) |
| 57 | # print("K:", K) |
| 58 | h = HiddenLayer(M1, K, T.nnet.softmax) |
| 59 | self.layers.append(h) |
| 60 | |
| 61 | if batch_sz is None: |
| 62 | batch_sz = N |
| 63 | |
| 64 | # collect params for later use |
| 65 | self.params = [] |
| 66 | for h in self.layers: |
| 67 | self.params += h.params |
| 68 | |
| 69 | # for momentum |
| 70 | dparams = [theano.shared(np.zeros_like(p.get_value())) for p in self.params] |
| 71 | |
| 72 | # set up theano functions and variables |
| 73 | thX = T.matrix('X') |
| 74 | thY = T.ivector('Y') |
| 75 | p_y_given_x = self.forward(thX) |
| 76 | |
| 77 | rcost = reg*T.mean([(p*p).sum() for p in self.params]) |
| 78 | cost = -T.mean(T.log(p_y_given_x[T.arange(thY.shape[0]), thY])) #+ rcost |
| 79 | prediction = T.argmax(p_y_given_x, axis=1) |
| 80 | grads = T.grad(cost, self.params) |
| 81 | |
| 82 | # momentum only |
| 83 | updates = [ |
| 84 | (p, p + mu*dp - learning_rate*g) for p, dp, g in zip(self.params, dparams, grads) |
| 85 | ] + [ |
| 86 | (dp, mu*dp - learning_rate*g) for dp, g in zip(dparams, grads) |
| 87 | ] |
| 88 | |
| 89 | train_op = theano.function( |
| 90 | inputs=[thX, thY], |
| 91 | outputs=[cost, prediction], |
| 92 | updates=updates, |
| 93 | ) |
| 94 | |
| 95 | self.predict_op = theano.function( |
| 96 | inputs=[thX], |
| 97 | outputs=prediction, |
| 98 | ) |
| 99 |
no test coverage detected