MCPcopy
hub / github.com/MingchaoZhu/DeepLearning / XGBoostRegressionTree

Class XGBoostRegressionTree

code/chapter7.py:676–723  ·  view source on GitHub ↗

XGBoost 回归树。此处基于第五章介绍的决策树,故采用贪心算法找到特征上分裂点 (枚举特征上所有可能的分裂点)。

Source from the content-addressed store, hash-verified

674
675#####----XGBoost----#######
class XGBoostRegressionTree(DecisionTree):
    """XGBoost regression tree.

    Built on top of the decision tree introduced in chapter 5, so split
    points are found greedily (every candidate split point of a feature is
    enumerated).

    The target matrix handed to this tree packs two things side by side:
    the true targets in the left half of the columns and the current
    ensemble predictions in the right half (see ``_split``).

    Parameters
    ----------
    min_samples_split, min_impurity, max_depth
        Forwarded unchanged to ``DecisionTree.__init__``.
    loss
        Loss object; this class calls ``loss.grad(y, y_pred)`` and
        ``loss.hess(y, y_pred)`` on it. The default ``None`` is only a
        placeholder: a real loss object must be supplied before fitting.
    gamma
        Penalty coefficient on the number of leaves.
    lambd
        Penalty coefficient on the leaf weights.
    """

    def __init__(self, min_samples_split=2, min_impurity=1e-7,
                 max_depth=float("inf"), loss=None, gamma=0., lambd=0.):
        super(XGBoostRegressionTree, self).__init__(
            min_impurity=min_impurity,
            min_samples_split=min_samples_split,
            max_depth=max_depth)
        self.gamma = gamma  # penalty coefficient on the number of leaves
        self.lambd = lambd  # penalty coefficient on the leaf weights
        self.loss = loss    # loss function object (provides grad / hess)

    def _split(self, y):
        """Split the packed matrix into ``(y_true, y_pred)``.

        ``y`` holds y_true in the left half of its columns and y_pred in
        the right half.
        """
        col = np.shape(y)[1] // 2
        return y[:, :col], y[:, col:]

    def _gain(self, y, y_pred):
        """Return the structure score ``G^2 / (H + lambd)`` of one node.

        ``G`` is the sum of first-order gradients and ``H`` the sum of
        second-order gradients of the loss over the node's samples.
        """
        # The gradient sum must NOT be weighted by the targets: the score is
        # (sum g)^2 / (sum h + lambda), matching the gradient sum used for
        # the leaf weight in _approximate_update.
        gradient_sum = self.loss.grad(y, y_pred).sum()
        hessian_sum = self.loss.hess(y, y_pred).sum()
        return np.power(gradient_sum, 2) / (hessian_sum + self.lambd)

    def _gain_by_taylor(self, y, y1, y2):
        """Return the gain of splitting node ``y`` into ``y1`` and ``y2``.

        All three arguments are packed ``[y_true | y_pred]`` matrices. The
        result is ``0.5 * (score(y1) + score(y2) - score(y)) - gamma``.
        """
        # Unpack the parent and both children.
        y, y_pred = self._split(y)
        y1, y1_pred = self._split(y1)
        y2, y2_pred = self._split(y2)
        true_gain = self._gain(y1, y1_pred)
        false_gain = self._gain(y2, y2_pred)
        gain = self._gain(y, y_pred)
        # Improvement in the structure score, minus the cost of one more leaf.
        return 0.5 * (true_gain + false_gain - gain) - self.gamma

    def _approximate_update(self, y):
        """Return the leaf weight ``-G / (H + lambd)`` for the samples in ``y``.

        ``y`` is a packed ``[y_true | y_pred]`` matrix.
        """
        y, y_pred = self._split(y)
        gradient = self.loss.grad(y, y_pred).sum()
        hessian = self.loss.hess(y, y_pred).sum()
        leaf_approximation = -gradient / (hessian + self.lambd)
        return leaf_approximation

    def fit(self, X, y):
        """Fit the tree on ``X`` and the packed ``[y_true | y_pred]`` matrix ``y``.

        Installs the gain and leaf-weight functions on the instance and then
        delegates to ``DecisionTree.fit``.
        """
        # Presumably the base class reads these two hooks while growing the
        # tree -- confirm against DecisionTree.
        self._impurity_calculation = self._gain_by_taylor
        self._leaf_value_calculation = self._approximate_update
        super(XGBoostRegressionTree, self).fit(X, y)
724
725
726class XGBoost(object):

Callers 1

fitMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected