(self, trees, lr=1e-1, mu=0.9, reg=0.1, epochs=5)
| 58 | self.params = [self.We, self.W1, self.W2, self.Wo] |
| 59 | |
| 60 | def fit(self, trees, lr=1e-1, mu=0.9, reg=0.1, epochs=5): |
| 61 | train_ops = [] |
| 62 | costs = [] |
| 63 | predictions = [] |
| 64 | all_labels = [] |
| 65 | i = 0 |
| 66 | N = len(trees) |
| 67 | print("Compiling ops") |
| 68 | for t in trees: |
| 69 | i += 1 |
| 70 | sys.stdout.write("%d/%d\r" % (i, N)) |
| 71 | sys.stdout.flush() |
| 72 | logits = self.get_output(t) |
| 73 | labels = get_labels(t) |
| 74 | all_labels.append(labels) |
| 75 | |
| 76 | cost = self.get_cost(logits, labels, reg) |
| 77 | costs.append(cost) |
| 78 | |
| 79 | prediction = tf.argmax(input=logits, axis=1) |
| 80 | predictions.append(prediction) |
| 81 | |
| 82 | train_op = tf.compat.v1.train.MomentumOptimizer(lr, mu).minimize(cost) |
| 83 | train_ops.append(train_op) |
| 84 | |
| 85 | # save for later so we don't have to recompile |
| 86 | self.predictions = predictions |
| 87 | self.all_labels = all_labels |
| 88 | self.saver = tf.compat.v1.train.Saver() |
| 89 | |
| 90 | init = tf.compat.v1.initialize_all_variables() |
| 91 | actual_costs = [] |
| 92 | per_epoch_costs = [] |
| 93 | correct_rates = [] |
| 94 | with tf.compat.v1.Session() as session: |
| 95 | session.run(init) |
| 96 | |
| 97 | for i in range(epochs): |
| 98 | t0 = datetime.now() |
| 99 | |
| 100 | train_ops, costs, predictions, all_labels = shuffle(train_ops, costs, predictions, all_labels) |
| 101 | epoch_cost = 0 |
| 102 | n_correct = 0 |
| 103 | n_total = 0 |
| 104 | j = 0 |
| 105 | N = len(train_ops) |
| 106 | for train_op, cost, prediction, labels in zip(train_ops, costs, predictions, all_labels): |
| 107 | _, c, p = session.run([train_op, cost, prediction]) |
| 108 | epoch_cost += c |
| 109 | actual_costs.append(c) |
| 110 | n_correct += np.sum(p == labels) |
| 111 | n_total += len(labels) |
| 112 | |
| 113 | j += 1 |
| 114 | if j % 10 == 0: |
| 115 | sys.stdout.write("j: %d, N: %d, c: %f\r" % (j, N, c)) |
| 116 | sys.stdout.flush() |
| 117 |
No test coverage detected.