(self, X, epochs=500, show_fig=False)
| 104 | |
| 105 | |
def fit(self, X, epochs=500, show_fig=False):
    """Train the model on a corpus of sentences of token indices.

    Fresh weights are created with ``init_weight`` / zeros, handed to
    ``self.build`` to construct the TensorFlow graph, and then optimised
    one sentence at a time by running ``self.train_op`` in
    ``self.session``.

    For every sentence the network input is ``[START, w1, ..., wn]`` and
    the target is ``[w1, ..., wn, END]``.

    Args:
        X: sequence of sentences; each sentence is a sequence of integer
            token indices (presumably in ``[0, self.V)`` -- confirm
            against the caller).
        epochs: number of full passes over ``X``.
        show_fig: when true, plot the per-epoch summed cost after training.

    Returns:
        None. Progress is printed once per epoch.
    """
    # Indices reserved for the sentence-boundary tokens. They were bare
    # literals before; kept local so the method stays self-contained.
    start_token = 0
    end_token = 1

    # Sizes read from the instance (presumably embedding size, hidden
    # size and vocabulary size, judging by the weight shapes below).
    D = self.D
    M = self.M
    V = self.V

    # initial weights
    We = init_weight(V, D).astype(np.float32)
    Wx = init_weight(D, M).astype(np.float32)
    Wh = init_weight(M, M).astype(np.float32)
    bh = np.zeros(M).astype(np.float32)
    h0 = np.zeros(M).astype(np.float32)
    Wo = init_weight(M, V).astype(np.float32)
    bo = np.zeros(V).astype(np.float32)

    # build tensorflow functions
    self.build(We, Wx, Wh, bh, h0, Wo, bo)

    # One prediction per word plus one for the END token of each sentence.
    n_total = sum((len(sentence) + 1) for sentence in X)

    costs = []
    for i in range(epochs):
        # Visit the sentences in a new order every epoch.
        X = shuffle(X)
        n_correct = 0
        cost = 0
        for sentence in X:
            # NOTE: every sentence contributes one "-> END" transition,
            # so END is over-represented relative to ordinary words and
            # generated lines tend to come out very short.
            #
            # list() copies the sentence, so tuples and arrays are
            # accepted as well; with the old `[0] + sentence` a tuple
            # raised TypeError and an ndarray did element-wise addition.
            tokens = list(sentence)
            input_sequence = [start_token] + tokens
            output_sequence = tokens + [end_token]

            _, c, p = self.session.run(
                (self.train_op, self.cost, self.predict_op),
                feed_dict={self.tfX: input_sequence, self.tfY: output_sequence}
            )
            cost += c
            n_correct += sum(
                1 for pj, xj in zip(p, output_sequence) if pj == xj
            )

        # Guard against an empty corpus, where n_total is 0.
        correct_rate = float(n_correct) / n_total if n_total else 0.0
        print("i:", i, "cost:", cost, "correct rate:", correct_rate)
        costs.append(cost)

    if show_fig:
        plt.plot(costs)
        plt.show()
| 160 | |
| 161 | def predict(self, prev_words): |
| 162 | # don't use argmax, so that we can sample |
no test coverage detected