MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / fit

Method fit

nlp_class2/recursive_tensorflow.py:60–138  ·  view source on GitHub ↗
(self, trees, lr=1e-1, mu=0.9, reg=0.1, epochs=5)

Source from the content-addressed store, hash-verified (excerpt: the listing below shows file lines 58–117; the method spans lines 60–138)

58 self.params = [self.We, self.W1, self.W2, self.Wo]
59
60 def fit(self, trees, lr=1e-1, mu=0.9, reg=0.1, epochs=5):
61 train_ops = []
62 costs = []
63 predictions = []
64 all_labels = []
65 i = 0
66 N = len(trees)
67 print("Compiling ops")
68 for t in trees:
69 i += 1
70 sys.stdout.write("%d/%d\r" % (i, N))
71 sys.stdout.flush()
72 logits = self.get_output(t)
73 labels = get_labels(t)
74 all_labels.append(labels)
75
76 cost = self.get_cost(logits, labels, reg)
77 costs.append(cost)
78
79 prediction = tf.argmax(input=logits, axis=1)
80 predictions.append(prediction)
81
82 train_op = tf.compat.v1.train.MomentumOptimizer(lr, mu).minimize(cost)
83 train_ops.append(train_op)
84
85 # save for later so we don't have to recompile
86 self.predictions = predictions
87 self.all_labels = all_labels
88 self.saver = tf.compat.v1.train.Saver()
89
90 init = tf.compat.v1.initialize_all_variables()
91 actual_costs = []
92 per_epoch_costs = []
93 correct_rates = []
94 with tf.compat.v1.Session() as session:
95 session.run(init)
96
97 for i in range(epochs):
98 t0 = datetime.now()
99
100 train_ops, costs, predictions, all_labels = shuffle(train_ops, costs, predictions, all_labels)
101 epoch_cost = 0
102 n_correct = 0
103 n_total = 0
104 j = 0
105 N = len(train_ops)
106 for train_op, cost, prediction, labels in zip(train_ops, costs, predictions, all_labels):
107 _, c, p = session.run([train_op, cost, prediction])
108 epoch_cost += c
109 actual_costs.append(c)
110 n_correct += np.sum(p == labels)
111 n_total += len(labels)
112
113 j += 1
114 if j % 10 == 0:
115 sys.stdout.write("j: %d, N: %d, c: %f\r" % (j, N, c))
116 sys.stdout.flush()
117

Callers 1

main · Function · 0.95

Calls 5

get_output · Method · 0.95
get_cost · Method · 0.95
get_labels · Function · 0.70
run · Method · 0.45
save · Method · 0.45

Tested by

no test coverage detected