Repository: github.com/MingchaoZhu/DeepLearning

Method fit

code/chapter9.py:1161–1203

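Parameters: X_train: training data; y_train: training labels; n_epochs: number of epochs; batch_size: mini-batch size used within each epoch; verbose: whether to print the loss after every batch; epo_verbose: whether to print the average loss after every epoch.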

(self, X_train, y_train, n_epochs=20, batch_size=64, verbose=False, epo_verbose=True)
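The two sizing parameters together fix how many optimizer steps one call performs. Assuming `minibatch` keeps the final, smaller batch (its source is not shown here), each epoch runs ceil(N / batch_size) batches and calls `self.update()` once per batch. The helper name `count_updates` below is hypothetical; it is only a sketch of that arithmetic:

```python
import math


def count_updates(n_samples: int, batch_size: int = 64, n_epochs: int = 20) -> int:
    """Hypothetical helper: optimizer steps taken by fit(), assuming the last partial batch is kept."""
    n_batch = math.ceil(n_samples / batch_size)  # batches per epoch
    return n_batch * n_epochs


# 1000 samples with the defaults: ceil(1000 / 64) = 16 batches per epoch, 16 * 20 = 320 updates
print(count_updates(1000))  # 320
```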


```python
# Module-level dependencies used below: numpy as np, time, minibatch, softmax
def fit(self, X_train, y_train, n_epochs=20, batch_size=64, verbose=False, epo_verbose=True):
    """
    Parameters:
        X_train: training data
        y_train: training labels
        n_epochs: number of epochs
        batch_size: mini-batch size used within each epoch
        verbose: whether to print the loss after every batch
        epo_verbose: whether to print the average loss after every epoch
    """
    self.verbose = verbose
    self.n_epochs = n_epochs
    self.batch_size = batch_size

    # Build the parameters on the first call, once the input width is known.
    if not self.is_initialized:
        self.n_features = X_train.shape[1]
        self._set_params()

    prev_loss = np.inf
    for i in range(n_epochs):
        loss, epoch_start = 0.0, time.time()
        # Shuffled batch indices, regenerated every epoch.
        batch_generator, n_batch = minibatch(X_train, self.batch_size, shuffle=True)

        for j, batch_idx in enumerate(batch_generator):
            batch_len, batch_start = len(batch_idx), time.time()
            X_batch, y_batch = X_train[batch_idx], y_train[batch_idx]
            out, _ = self.forward(X_batch)      # raw scores before softmax
            y_pred_batch = softmax(out)         # class probabilities
            batch_loss = self.loss(y_batch, y_pred_batch)
            grad = self.loss.grad(y_batch, y_pred_batch)
            _, _ = self.backward(grad)          # backpropagate through the network
            self.update()                       # update parameters from the gradients
            loss += batch_loss

            if self.verbose:
                fstr = "\t[Batch {}/{}] Train loss: {:.3f} ({:.1f}s/batch)"
                print(fstr.format(j + 1, n_batch, batch_loss, time.time() - batch_start))

        loss /= n_batch  # mean of the per-batch losses
        if epo_verbose:
            fstr = "[Epoch {}] Avg. loss: {:.3f} Delta: {:.3f} ({:.2f}m/epoch)"
            print(fstr.format(i + 1, loss, prev_loss - loss, (time.time() - epoch_start) / 60.0))
        prev_loss = loss
```
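`fit` relies on a module-level `minibatch` helper that returns a generator of row-index arrays together with the batch count `n_batch`, which the loop uses both for the `[Batch j/n_batch]` message and to average the epoch loss. Its definition is not part of this excerpt. The following is a minimal sketch that matches the call `minibatch(X_train, self.batch_size, shuffle=True)`, assuming the last batch may be smaller; the `rng` parameter is an addition for reproducibility and is not part of the original call.

```python
import numpy as np


def minibatch(X, batchsize=256, shuffle=True, rng=None):
    """Sketch (not the repository's definition): return (generator of index arrays, number of batches)."""
    n = X.shape[0]
    n_batch = int(np.ceil(n / batchsize))  # the last batch may be smaller
    if shuffle:
        rng = np.random.default_rng() if rng is None else rng
        idx = rng.permutation(n)  # new random order on every call, i.e. every epoch
    else:
        idx = np.arange(n)

    def generator():
        for b in range(n_batch):
            yield idx[b * batchsize:(b + 1) * batchsize]

    return generator(), n_batch


X = np.zeros((10, 3))
gen, n_batch = minibatch(X, batchsize=4, rng=np.random.default_rng(0))
sizes = [len(b) for b in gen]
print(n_batch, sizes)  # 3 [4, 4, 2]
```

Returning the count separately from a lazy generator lets the caller report progress and normalize the loss without materializing every batch up front. Because `X_train[batch_idx]` uses NumPy integer-array indexing, each batch is a copy of the selected rows.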

Callers

nothing calls this directly

Calls (8)

_set_params · method · 0.95
forward · method · 0.95
backward · method · 0.95
update · method · 0.95
minibatch · function · 0.90
softmax · function · 0.90
loss · method · 0.45
grad · method · 0.45
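`fit` applies `softmax` to the output of `forward`, then passes `self.loss.grad(y_batch, y_pred_batch)` straight into `self.backward`. Since `forward` returns the pre-softmax scores, that gradient presumably has to be taken with respect to those logits. Neither helper's source is shown here; a common pairing, assumed in this sketch, is summed cross-entropy on one-hot labels, whose gradient with respect to the logits simplifies to `y_pred - y`. The class name `CrossEntropySketch` is hypothetical, and the finite-difference loop checks the identity numerically.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


class CrossEntropySketch:
    """Hypothetical stand-in for self.loss: cross-entropy summed over a batch of one-hot labels."""
    eps = np.finfo(float).eps  # guards against log(0)

    def __call__(self, y, y_pred):
        return -np.sum(y * np.log(y_pred + self.eps))

    def grad(self, y, y_pred):
        # Gradient w.r.t. the logits z, valid when y_pred = softmax(z) and each row of y sums to 1.
        return y_pred - y


# Finite-difference check of grad on a random batch
rng = np.random.default_rng(0)
z = rng.normal(size=(5, 3))
y = np.eye(3)[rng.integers(0, 3, size=5)]  # one-hot labels
loss = CrossEntropySketch()

analytic = loss.grad(y, softmax(z))
numeric = np.zeros_like(z)
h = 1e-6
for idx in np.ndindex(*z.shape):
    zp, zm = z.copy(), z.copy()
    zp[idx] += h
    zm[idx] -= h
    numeric[idx] = (loss(y, softmax(zp)) - loss(y, softmax(zm))) / (2 * h)

print(np.max(np.abs(analytic - numeric)))  # should be close to 0
```

The simplification follows from writing the loss as the negated dot product of `y` with `z`, plus `(sum of y) * log(sum of exp(z))`; differentiating gives `softmax(z) - y` once each row of `y` sums to 1.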

Tested by

no test coverage detected