Repository: github.com/MingchaoZhu/DeepLearning

Class RegressionTree

code/chapter5.py:489–527

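Regression tree: at each decision-tree node the split is chosen by MSE/variance reduction, and each leaf node predicts the mean of its targets.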
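Written out, the summary above amounts to the following split score $\Delta$ and leaf value $\hat{y}_{\text{leaf}}$. This assumes that the helpers `calculate_mse` and `calculate_variance` (not shown in this excerpt) both compute the population variance of the targets, i.e. their mean squared deviation from the mean, so that for a single target column the MSE and variance criteria coincide. Here $y_1$ and $y_2$ are the targets sent to the two children of a candidate split, and the split with the largest $\Delta$ is kept.

$$
\Delta(y; y_1, y_2) = I(y) - \left( \frac{|y_1|}{|y|}\, I(y_1) + \frac{|y_2|}{|y|}\, I(y_2) \right),
\qquad
I(y) = \frac{1}{|y|} \sum_{i=1}^{|y|} \left( y_i - \bar{y} \right)^2,
\qquad
\hat{y}_{\text{leaf}} = \bar{y} = \frac{1}{|y|} \sum_{i=1}^{|y|} y_i
$$

For example, splitting $y = (1, 1, 1, 5, 5, 5)$ into $y_1 = (5, 5, 5)$ and $y_2 = (1, 1, 1)$ gives $I(y) = 4$ and $I(y_1) = I(y_2) = 0$, so $\Delta = 4 - (\tfrac{1}{2} \cdot 0 + \tfrac{1}{2} \cdot 0) = 4$, and the two leaves predict 5 and 1. For multi-output targets, `_calculate_variance_reduction` sums $\Delta$ over the output columns.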


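```python
import numpy as np

# DecisionTree, calculate_mse and calculate_variance are provided elsewhere
# in code/chapter5.py; they are not shown in this excerpt.


class RegressionTree(DecisionTree):
    """
    Regression tree: at each decision-tree node the split is chosen by
    MSE/variance reduction, and each leaf predicts the mean of its targets.
    """
    def _calculate_mse(self, y, y1, y2):
        """
        Compute the MSE reduction obtained by splitting y into y1 and y2.
        """
        mse_tot = calculate_mse(y)
        mse_1 = calculate_mse(y1)
        mse_2 = calculate_mse(y2)
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        mse_reduction = mse_tot - (frac_1 * mse_1 + frac_2 * mse_2)
        return mse_reduction

    def _calculate_variance_reduction(self, y, y1, y2):
        """
        Compute the variance reduction obtained by splitting y into y1 and y2.
        """
        var_tot = calculate_variance(y)
        var_1 = calculate_variance(y1)
        var_2 = calculate_variance(y2)
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        variance_reduction = var_tot - (frac_1 * var_1 + frac_2 * var_2)
        # Sum per-output reductions into one score; np.sum also accepts a
        # scalar variance, where the built-in sum() would raise TypeError.
        return np.sum(variance_reduction)

    def _mean_of_y(self, y):
        """
        Leaf value: the mean of the targets (a scalar for single-output y).
        """
        # atleast_1d guards against 1-D y, where np.mean returns a scalar
        # that has no len().
        value = np.atleast_1d(np.mean(y, axis=0))
        return value if len(value) > 1 else value[0]

    def fit(self, X, y):
        # Plug the regression-specific split score and leaf rule into the
        # generic DecisionTree, then build the tree.
        self._impurity_calculation = self._calculate_mse
        self._leaf_value_calculation = self._mean_of_y
        super().fit(X, y)
```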
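Because `DecisionTree`, `calculate_mse` and `calculate_variance` are not part of this excerpt, the class cannot be run on its own. The following is a minimal standalone sketch, not the repository's implementation, of how a greedy tree builder could use the two hooks that `fit` installs: a split score and a leaf-value rule. All names here (`column_variance`, `variance_reduction`, `mean_of_y`, `best_split`, `build`, `predict_one`), the `>=` threshold rule, the depth/size limits and the nested-dict tree layout are hypothetical choices for illustration. It scores splits by variance reduction using `np.var` (population variance), which equals MSE reduction if `calculate_mse` measures squared deviation from the mean.

```python
import numpy as np

# Hypothetical standalone sketch; not the repository's DecisionTree.


def column_variance(y):
    """Population variance of each target column (1-D y is treated as one column)."""
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    return np.var(y, axis=0)


def variance_reduction(y, y1, y2):
    """Size-weighted variance reduction of splitting y into y1 and y2, summed over outputs."""
    frac_1 = len(y1) / len(y)
    frac_2 = len(y2) / len(y)
    reduction = column_variance(y) - (frac_1 * column_variance(y1) + frac_2 * column_variance(y2))
    return float(np.sum(reduction))


def mean_of_y(y):
    """Leaf value: column means, collapsed to a scalar for single-output targets."""
    value = np.atleast_1d(np.mean(np.asarray(y, dtype=float), axis=0))
    return value if len(value) > 1 else value[0]


def best_split(X, y):
    """Try every (feature, threshold) pair; keep the one with the largest reduction."""
    best = (None, None, 0.0)  # (feature index, threshold, reduction)
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            mask = X[:, j] >= t
            if mask.all() or not mask.any():
                continue  # a split must send samples to both children
            gain = variance_reduction(y, y[mask], y[~mask])
            if gain > best[2]:
                best = (j, float(t), gain)
    return best


def build(X, y, depth=0, max_depth=2, min_samples=2):
    """Grow the tree recursively as nested dicts; leaves store mean_of_y(y)."""
    if depth < max_depth and len(y) >= min_samples:
        j, t, _ = best_split(X, y)
        if j is not None:
            mask = X[:, j] >= t
            return {
                "feature": j,
                "threshold": t,
                "ge": build(X[mask], y[mask], depth + 1, max_depth, min_samples),
                "lt": build(X[~mask], y[~mask], depth + 1, max_depth, min_samples),
            }
    return mean_of_y(y)


def predict_one(tree, x):
    """Follow thresholds from the root until a leaf value is reached."""
    node = tree
    while isinstance(node, dict):
        node = node["ge"] if x[node["feature"]] >= node["threshold"] else node["lt"]
    return node


# Toy data: two clusters of targets, separated at x = 10.
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])
tree = build(X, y)
print(best_split(X, y))                                      # (0, 10.0, 4.0)
print(predict_one(tree, [11.5]), predict_one(tree, [0.0]))   # 5.0 1.0
```

The exhaustive threshold search is the simplest correct choice for a sketch; it costs one variance computation per distinct feature value at every node, which is fine for small data but is the first thing a faster implementation would replace (for example with sorted prefix sums).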
