MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / RNN

Class RNN

rnn_class/batch_wiki.py:24–129  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

22
23
24class RNN:
def __init__(self, D, hidden_layer_sizes, V):
    """Record the model dimensions; no weights are created here.

    All parameters (embedding matrix, recurrent layers, output layer)
    are built later inside ``fit``.

    Args:
        D: width of each word embedding (columns of the ``V x D``
            embedding matrix created in ``fit``).
        hidden_layer_sizes: output size of each stacked recurrent unit,
            in order from the input side to the output side.
        V: vocabulary size; also the number of softmax outputs.
    """
    self.V = V
    self.D = D
    self.hidden_layer_sizes = hidden_layer_sizes
29
30 def fit(self, X, learning_rate=1e-4, mu=0.99, epochs=10, batch_sz=100, show_fig=True, activation=T.nnet.relu, RecurrentUnit=LSTM):
31 D = self.D
32 V = self.V
33 N = len(X)
34
35 We = init_weight(V, D)
36 self.hidden_layers = []
37 Mi = D
38 for Mo in self.hidden_layer_sizes:
39 ru = RecurrentUnit(Mi, Mo, activation)
40 self.hidden_layers.append(ru)
41 Mi = Mo
42
43 Wo = init_weight(Mi, V)
44 bo = np.zeros(V)
45
46 self.We = theano.shared(We)
47 self.Wo = theano.shared(Wo)
48 self.bo = theano.shared(bo)
49 self.params = [self.We, self.Wo, self.bo]
50 for ru in self.hidden_layers:
51 self.params += ru.params
52
53 thX = T.ivector('X') # will represent multiple batches concatenated
54 thY = T.ivector('Y') # represents next word
55 thStartPoints = T.ivector('start_points')
56
57 Z = self.We[thX]
58 for ru in self.hidden_layers:
59 Z = ru.output(Z, thStartPoints)
60 py_x = T.nnet.softmax(Z.dot(self.Wo) + self.bo)
61 prediction = T.argmax(py_x, axis=1)
62
63 cost = -T.mean(T.log(py_x[T.arange(thY.shape[0]), thY]))
64 grads = T.grad(cost, self.params)
65 dparams = [theano.shared(p.get_value()*0) for p in self.params]
66
67 updates = [
68 (p, p + mu*dp - learning_rate*g) for p, dp, g in zip(self.params, dparams, grads)
69 ] + [
70 (dp, mu*dp - learning_rate*g) for dp, g in zip(dparams, grads)
71 ]
72
73 # self.predict_op = theano.function(inputs=[thX, thStartPoints], outputs=prediction)
74 self.train_op = theano.function(
75 inputs=[thX, thY, thStartPoints],
76 outputs=[cost, prediction],
77 updates=updates
78 )
79
80 costs = []
81 n_batches = N // batch_sz

Callers 1

train_wikipedia (function) · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected