MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / fit

Method fit

rnn_class/batch_parity.py:22–152  ·  view source on GitHub ↗
(self, X, Y, batch_sz=20, learning_rate=1.0, mu=0.99, reg=1.0, activation=T.tanh, epochs=100, show_fig=False)

Source from the content-addressed store, hash-verified

20 self.M = M # hidden layer size
21
22 def fit(self, X, Y, batch_sz=20, learning_rate=1.0, mu=0.99, reg=1.0, activation=T.tanh, epochs=100, show_fig=False):
23 D = X[0].shape[1] # X is of size N x T(n) x D
24 K = len(set(Y.flatten()))
25 N = len(Y)
26 M = self.M
27 self.f = activation
28
29 # initial weights
30 Wx = init_weight(D, M)
31 Wh = init_weight(M, M)
32 bh = np.zeros(M)
33 h0 = np.zeros(M)
34 Wo = init_weight(M, K)
35 bo = np.zeros(K)
36
37 # make them theano shared
38 self.Wx = theano.shared(Wx)
39 self.Wh = theano.shared(Wh)
40 self.bh = theano.shared(bh)
41 self.h0 = theano.shared(h0)
42 self.Wo = theano.shared(Wo)
43 self.bo = theano.shared(bo)
44 self.params = [self.Wx, self.Wh, self.bh, self.h0, self.Wo, self.bo]
45
46 thX = T.fmatrix('X') # will represent multiple batches concatenated
47 thY = T.ivector('Y')
48 thStartPoints = T.ivector('start_points')
49
50 XW = thX.dot(self.Wx)
51
52 # startPoints will contain 1 where a sequence starts and 0 otherwise
53 # Ex. if I have 3 sequences: [[1,2,3], [4,5], [6,7,8]]
54 # Then I will concatenate these into one X: [1,2,3,4,5,6,7,8]
55 # And startPoints will be [1,0,0,1,0,1,0,0]
56
57 # One possible solution: loop through index
58 # def recurrence(t, h_t1, XW, h0, startPoints):
59 # # returns h(t)
60
61 # # if at a boundary, state should be h0
62 # h_t = T.switch(
63 # T.eq(startPoints[t], 1),
64 # self.f(XW[t] + h0.dot(self.Wh) + self.bh),
65 # self.f(XW[t] + h_t1.dot(self.Wh) + self.bh)
66 # )
67 # return h_t
68
69 # h, _ = theano.scan(
70 # fn=recurrence,
71 # outputs_info=[self.h0],
72 # sequences=T.arange(XW.shape[0]),
73 # non_sequences=[XW, self.h0, thStartPoints],
74 # n_steps=XW.shape[0],
75 # )
76
77 # other solution - loop through all sequences simultaneously
78 def recurrence(xw_t, is_start, h_t1, h0):
79 # if at a boundary, state should be h0

Callers 1

parity (Function) · 0.95

Calls 2

init_weight (Function) · 0.90
grad (Method) · 0.45

Tested by

no test coverage detected