github.com/MingchaoZhu/DeepLearning

Method fit

code/chapter7.py:594–627
(self, X, Y)


        self.loss = CrossEntropyLoss()

    def fit(self, X, Y):
        # For classification, convert Y to a one-hot encoding
        if not self.is_regression:
            Y = to_categorical(Y.flatten())
        else:
            Y = Y.reshape(-1, 1) if len(Y.shape) == 1 else Y
        self.out_dims = Y.shape[1]
        self.trees = np.empty((self.n_estimators, self.out_dims), dtype=object)
        Y_pred = np.full(np.shape(Y), np.mean(Y, axis=0))
        self.weights = np.ones((self.n_estimators, self.out_dims))
        self.weights[1:, :] *= self.learning_rate
        # Boosting iterations
        for i in self.progressbar(range(self.n_estimators)):
            for c in range(self.out_dims):
                tree = RegressionTree(
                    min_samples_split=self.min_samples_split,
                    min_impurity=self.min_impurity,
                    max_depth=self.max_depth)
                # Compute the gradient of the loss and fit the tree to its negative
                if not self.is_regression:
                    Y_hat = softmax(Y_pred)
                    y, y_pred = Y[:, c], Y_hat[:, c]
                else:
                    y, y_pred = Y[:, c], Y_pred[:, c]
                neg_grad = -1 * self.loss.grad(y, y_pred)
                tree.fit(X, neg_grad)
                # Predict with the new base learner
                h_pred = tree.predict(X)
                # Optional line search: rescale this tree's step size
                if self.line_search:
                    self.weights[i, c] *= line_search(y, y_pred, h_pred)
                # Add the base learner's prediction to the additive model to get
                # the ensemble prediction after this iteration
                Y_pred[:, c] += np.multiply(self.weights[i, c], h_pred)
                self.trees[i, c] = tree

    def predict(self, X):
        Y_pred = np.zeros((X.shape[0], self.out_dims))
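To make the regression branch of the loop concrete without the repository's RegressionTree, loss classes, or progressbar, here is a minimal NumPy sketch under stated assumptions: squared-error loss L = 0.5 * (y - ŷ)², whose negative gradient with respect to ŷ is simply the residual y - ŷ, and a hypothetical depth-1 `Stump` class standing in for RegressionTree. The names `Stump`, `fit_boosting`, and `predict_boosting` are illustrative, not from the repository, and line search is omitted. The sketch keeps fit's weight schedule (first tree at weight 1, later trees at `learning_rate`) and its mean initialisation, and it stores that initial mean so prediction starts from the same baseline used in training.

```python
import numpy as np


class Stump:
    """Hypothetical depth-1 regression tree, a stand-in for RegressionTree."""

    def fit(self, X, r):
        best = None
        for j in range(X.shape[1]):
            for t in np.unique(X[:, j]):
                left = X[:, j] <= t
                if left.all():
                    continue  # a split must send samples to both sides
                lv, rv = r[left].mean(), r[~left].mean()
                sse = ((r[left] - lv) ** 2).sum() + ((r[~left] - rv) ** 2).sum()
                if best is None or sse < best[0]:
                    best = (sse, j, t, lv, rv)
        if best is None:  # every feature is constant: predict the mean everywhere
            self.j, self.t, self.lv, self.rv = 0, np.inf, r.mean(), r.mean()
        else:
            _, self.j, self.t, self.lv, self.rv = best
        return self

    def predict(self, X):
        return np.where(X[:, self.j] <= self.t, self.lv, self.rv)


def fit_boosting(X, y, n_estimators=50, learning_rate=0.1):
    """Squared-loss gradient boosting mirroring the regression branch of fit()."""
    y = y.reshape(-1, 1) if y.ndim == 1 else y
    init = y.mean(axis=0)                     # start from the column means
    y_pred = np.full(y.shape, init)
    weights = np.ones((n_estimators, y.shape[1]))
    weights[1:, :] *= learning_rate           # first tree at full weight
    trees = np.empty((n_estimators, y.shape[1]), dtype=object)
    for i in range(n_estimators):
        for c in range(y.shape[1]):
            # For squared loss, the negative gradient is the residual
            neg_grad = y[:, c] - y_pred[:, c]
            tree = Stump().fit(X, neg_grad)
            y_pred[:, c] += weights[i, c] * tree.predict(X)
            trees[i, c] = tree
    return init, weights, trees


def predict_boosting(model, X):
    init, weights, trees = model
    out = np.tile(init, (X.shape[0], 1)).astype(float)
    for i in range(trees.shape[0]):
        for c in range(trees.shape[1]):
            out[:, c] += weights[i, c] * trees[i, c].predict(X)
    return out


rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0])
model = fit_boosting(X, y, n_estimators=100, learning_rate=0.1)
mse = np.mean((predict_boosting(model, X)[:, 0] - y) ** 2)
print(f"training MSE: {mse:.4f}")
```

Because each stump predicts leaf means of the current residuals, adding a fraction 0 < ν ≤ 1 of its output never increases the training squared error, so the printed MSE should fall well below the variance of `y`.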

Callers

nothing calls this directly

Calls (7)

fit (method) · 0.95
RegressionTree (class) · 0.90
to_categorical (function) · 0.85
line_search (function) · 0.85
softmax (function) · 0.70
grad (method) · 0.45
predict (method) · 0.45
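The classification branch depends on `to_categorical` and `softmax`, whose bodies are not shown here. The sketch below assumes `to_categorical` maps integer labels to one-hot rows and `softmax` normalises each row of class scores; both are illustrative re-implementations, not the repository's code. It also shows a common formulation in which, for softmax followed by cross-entropy, the gradient with respect to the raw scores is P - Y. The repository's `CrossEntropyLoss.grad` receives probabilities rather than scores and may use a different formula.

```python
import numpy as np


def to_categorical(y, n_classes=None):
    """Hypothetical one-hot encoder: integer labels -> (n_samples, n_classes)."""
    y = np.asarray(y, dtype=int)
    n_classes = n_classes if n_classes is not None else y.max() + 1
    one_hot = np.zeros((y.shape[0], n_classes))
    one_hot[np.arange(y.shape[0]), y] = 1.0
    return one_hot


def softmax(z):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


Y = to_categorical(np.array([0, 2, 1]))
P = softmax(np.zeros_like(Y))  # equal scores give uniform probabilities
# With softmax + cross-entropy, d(loss)/d(scores) = P - Y, so the negative
# gradient a per-class tree would be fitted to is Y - P.
neg_grad = Y - P
print(Y)
print(neg_grad[:, 0])
```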

Tested by

no test coverage detected