(self, X, Y, learning_rate=1.0, mu=0.99, reg=1.0, activation=tf.tanh, epochs=100, show_fig=False)
| 19 | self.M = M # hidden layer size |
| 20 | |
| 21 | def fit(self, X, Y, learning_rate=1.0, mu=0.99, reg=1.0, activation=tf.tanh, epochs=100, show_fig=False): |
| 22 | N, T, D = X.shape |
| 23 | K = len(set(Y.flatten())) |
| 24 | M = self.M |
| 25 | self.f = activation |
| 26 | |
| 27 | # initial weights |
| 28 | Wx = init_weight(D, M).astype(np.float32) |
| 29 | Wh = init_weight(M, M).astype(np.float32) |
| 30 | bh = np.zeros(M, dtype=np.float32) |
| 31 | h0 = np.zeros(M, dtype=np.float32) |
| 32 | Wo = init_weight(M, K).astype(np.float32) |
| 33 | bo = np.zeros(K, dtype=np.float32) |
| 34 | |
| 35 | # make them theano shared |
| 36 | self.Wx = tf.Variable(Wx) |
| 37 | self.Wh = tf.Variable(Wh) |
| 38 | self.bh = tf.Variable(bh) |
| 39 | self.h0 = tf.Variable(h0) |
| 40 | self.Wo = tf.Variable(Wo) |
| 41 | self.bo = tf.Variable(bo) |
| 42 | |
| 43 | tfX = tf.placeholder(tf.float32, shape=(T, D), name='X') |
| 44 | tfY = tf.placeholder(tf.int32, shape=(T,), name='Y') |
| 45 | |
| 46 | XWx = tf.matmul(tfX, self.Wx) |
| 47 | |
| 48 | def recurrence(h_t1, xw_t): |
| 49 | # matmul() only works with 2-D objects |
| 50 | # we want to return a 1-D object of size M |
| 51 | # so that the final result is T x M |
| 52 | # not T x 1 x M |
| 53 | h_t = self.f(xw_t + tf.matmul(tf.reshape(h_t1, (1, M)), self.Wh) + self.bh) |
| 54 | return tf.reshape(h_t, (M,)) |
| 55 | |
| 56 | h = tf.scan( |
| 57 | fn=recurrence, |
| 58 | elems=XWx, |
| 59 | initializer=self.h0, |
| 60 | ) |
| 61 | |
| 62 | logits = tf.matmul(h, self.Wo) + self.bo |
| 63 | |
| 64 | cost = tf.reduce_mean( |
| 65 | tf.nn.sparse_softmax_cross_entropy_with_logits( |
| 66 | labels=tfY, |
| 67 | logits=logits, |
| 68 | ) |
| 69 | ) |
| 70 | |
| 71 | predict_op = tf.argmax(logits, 1) |
| 72 | train_op = tf.train.AdamOptimizer(1e-2).minimize(cost) |
| 73 | |
| 74 | init = tf.global_variables_initializer() |
| 75 | with tf.Session() as session: |
| 76 | session.run(init) |
| 77 | |
| 78 | costs = [] |
no test coverage detected