MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / fit

Method fit

rnn_class/srn_language_tf.py:106–159  ·  view source on GitHub ↗
(self, X, epochs=500, show_fig=False)

Source from the content-addressed store, hash-verified

104
105
def fit(self, X, epochs=500, show_fig=False):
    """Train the language model, one sentence per training step.

    Fresh weights are created, handed to ``self.build`` to construct the
    TensorFlow graph, and then ``self.train_op`` is run once per sentence
    for ``epochs`` passes over ``X``. After every epoch the summed cost
    and the fraction of correctly predicted targets are printed.

    Args:
        X: collection of sentences. Each sentence is concatenated with a
            one-element list, so it must itself be a list (presumably of
            integer word indices -- confirm against the caller).
        epochs: number of full passes over ``X``.
        show_fig: when true, plot the per-epoch cost with ``plt`` once
            training finishes.

    Returns:
        None. The trained state lives in the graph built by ``self.build``.
    """
    # Ids used to mark sentence boundaries. They were bare literals before;
    # naming them keeps the input/target construction below readable.
    start_token = 0
    end_token = 1

    D = self.D
    M = self.M
    V = self.V

    # initial weights (creation order kept so any RNG draws are unchanged)
    We = init_weight(V, D).astype(np.float32)
    Wx = init_weight(D, M).astype(np.float32)
    Wh = init_weight(M, M).astype(np.float32)
    bh = np.zeros(M).astype(np.float32)
    h0 = np.zeros(M).astype(np.float32)
    Wo = init_weight(M, V).astype(np.float32)
    bo = np.zeros(V).astype(np.float32)

    # build tensorflow functions
    self.build(We, Wx, Wh, bh, h0, Wo, bo)

    # sentence input:
    # [START, w1, w2, ..., wn]
    # sentence target:
    # [w1, w2, w3, ..., END]

    costs = []
    # Every sentence contributes len(sentence) words plus one END target.
    n_total = sum((len(sentence) + 1) for sentence in X)
    for i in range(epochs):
        X = shuffle(X)
        n_correct = 0
        cost = 0
        for sentence in X:
            # Known limitation: every sentence ends in END, so the END token
            # is over-represented among the targets and generated lines
            # tend to come out very short.
            input_sequence = [start_token] + sentence
            output_sequence = sentence + [end_token]

            _, c, p = self.session.run(
                (self.train_op, self.cost, self.predict_op),
                feed_dict={self.tfX: input_sequence, self.tfY: output_sequence}
            )
            cost += c
            # Count positions where the prediction equals the target.
            n_correct += sum(
                1 for pj, xj in zip(p, output_sequence) if pj == xj
            )

        # An empty X gives n_total == 0; report 0.0 rather than dividing by zero.
        correct_rate = (float(n_correct) / n_total) if n_total else 0.0
        print("i:", i, "cost:", cost, "correct rate:", correct_rate)
        costs.append(cost)

    if show_fig:
        plt.plot(costs)
        plt.show()
160
161 def predict(self, prev_words):
162 # don't use argmax, so that we can sample

Callers 1

train_poetry · Function · 0.95

Calls 3

build · Method · 0.95
init_weight · Function · 0.90
run · Method · 0.45

Tested by

no test coverage detected