MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / fit

Method fit

rnn_class/srn_parity_tf.py:21–95  ·  view source on GitHub ↗
(self, X, Y, learning_rate=1.0, mu=0.99, reg=1.0, activation=tf.tanh, epochs=100, show_fig=False)

Source from the content-addressed store, hash-verified

19 self.M = M # hidden layer size
20
21 def fit(self, X, Y, learning_rate=1.0, mu=0.99, reg=1.0, activation=tf.tanh, epochs=100, show_fig=False):
22 N, T, D = X.shape
23 K = len(set(Y.flatten()))
24 M = self.M
25 self.f = activation
26
27 # initial weights
28 Wx = init_weight(D, M).astype(np.float32)
29 Wh = init_weight(M, M).astype(np.float32)
30 bh = np.zeros(M, dtype=np.float32)
31 h0 = np.zeros(M, dtype=np.float32)
32 Wo = init_weight(M, K).astype(np.float32)
33 bo = np.zeros(K, dtype=np.float32)
34
35 # wrap the initial weights in TensorFlow variables
36 self.Wx = tf.Variable(Wx)
37 self.Wh = tf.Variable(Wh)
38 self.bh = tf.Variable(bh)
39 self.h0 = tf.Variable(h0)
40 self.Wo = tf.Variable(Wo)
41 self.bo = tf.Variable(bo)
42
43 tfX = tf.placeholder(tf.float32, shape=(T, D), name='X')
44 tfY = tf.placeholder(tf.int32, shape=(T,), name='Y')
45
46 XWx = tf.matmul(tfX, self.Wx)
47
48 def recurrence(h_t1, xw_t):
49 # matmul() only works with 2-D objects
50 # we want to return a 1-D object of size M
51 # so that the final result is T x M
52 # not T x 1 x M
53 h_t = self.f(xw_t + tf.matmul(tf.reshape(h_t1, (1, M)), self.Wh) + self.bh)
54 return tf.reshape(h_t, (M,))
55
56 h = tf.scan(
57 fn=recurrence,
58 elems=XWx,
59 initializer=self.h0,
60 )
61
62 logits = tf.matmul(h, self.Wo) + self.bo
63
64 cost = tf.reduce_mean(
65 tf.nn.sparse_softmax_cross_entropy_with_logits(
66 labels=tfY,
67 logits=logits,
68 )
69 )
70
71 predict_op = tf.argmax(logits, 1)
72 train_op = tf.train.AdamOptimizer(1e-2).minimize(cost)
73
74 init = tf.global_variables_initializer()
75 with tf.Session() as session:
76 session.run(init)
77
78 costs = []

Callers 1

parity (Function) · 0.95

Calls 2

init_weight (Function) · 0.90
run (Method) · 0.45

Tested by

no test coverage detected