github.com/ddbourgin/numpy-ml

Class GradientBoostedDecisionTree

numpy_ml/trees/gbdt.py:18–181

class GradientBoostedDecisionTree:
    def __init__(
        self,
        n_iter,
        max_depth=None,
        classifier=True,
        learning_rate=1,
        loss="crossentropy",
        step_size="constant",
    ):
        r"""
        A gradient boosted ensemble of decision trees.

        Notes
        -----
        Gradient boosted machines (GBMs) fit an ensemble of `m` weak learners such that:

        .. math::

            f_m(X) = b(X) + \eta w_1 g_1 + \ldots + \eta w_m g_m

        where `b` is a fixed initial estimate for the targets, :math:`\eta` is
        a learning rate parameter, and :math:`w_{\cdot}` and :math:`g_{\cdot}`
        denote the weights and learner predictions for subsequent fits.

        We fit each `w` and `g` iteratively using a greedy strategy so that at each
        iteration `i`,

        .. math::

            w_i, g_i = \arg \min_{w_i, g_i} L(Y, f_{i-1}(X) + w_i g_i)

        On each iteration we fit a new weak learner to predict the negative
        gradient of the loss with respect to the previous prediction, :math:`f_{i-1}(X)`.
        We then scale the predictions of this weak learner, :math:`g_i`, by a
        weight, :math:`w_i`, and by the learning rate, :math:`\eta`, to compute
        the amount by which to adjust the predictions of our model at the
        previous iteration, :math:`f_{i-1}(X)`:

        .. math::

            f_i(X) := f_{i-1}(X) + \eta w_i g_i

        Parameters
        ----------
        n_iter : int
            The number of iterations / weak estimators to use when fitting each
            dimension / class of `Y`.
        max_depth : int or None
            The maximum depth of each decision tree weak estimator. Default is
            None.
        classifier : bool
            Whether `Y` contains class labels or real-valued targets. Default
            is True.
        learning_rate : float
            Value in [0, 1] controlling the amount each weak estimator
            contributes to the overall model prediction. Sometimes known as the
            `shrinkage parameter` in the GBM literature. Default is 1.
        loss : {'crossentropy', 'mse'}
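The Notes in the docstring describe the generic boosting loop. What follows is a minimal NumPy sketch of that loop under stated assumptions, not numpy-ml's implementation: it assumes the squared-error loss, takes b(X) to be the mean of the targets, picks each w_i by exact line search, and uses the hypothetical helpers `fit_stump` / `predict_stump` (depth-1 regression trees) in place of the library's decision trees. `gbm_fit` and `gbm_predict` are likewise illustrative names for a single-output regression target.

```python
import numpy as np


def fit_stump(X, r):
    """Hypothetical weak learner: a depth-1 regression tree fit to r by least squares.

    Returns (feature, threshold, left_value, right_value); each leaf value is
    the mean of r over the samples that fall in that leaf.
    """
    best, best_sse = None, np.inf
    for j in range(X.shape[1]):
        # Skip the largest distinct value so both sides of the split are non-empty.
        for t in np.unique(X[:, j])[:-1]:
            left = X[:, j] <= t
            lv, rv = r[left].mean(), r[~left].mean()
            sse = ((r[left] - lv) ** 2).sum() + ((r[~left] - rv) ** 2).sum()
            if sse < best_sse:
                best, best_sse = (j, t, lv, rv), sse
    if best is None:
        # No valid split (every feature is constant): predict the mean everywhere.
        return (0, np.inf, r.mean(), r.mean())
    return best


def predict_stump(stump, X):
    j, t, lv, rv = stump
    return np.where(X[:, j] <= t, lv, rv)


def gbm_fit(X, y, n_iter=100, learning_rate=0.1):
    """Boost stumps under the (assumed) squared-error loss L = 0.5 * sum((y - f) ** 2)."""
    b = y.mean()                        # fixed initial estimate b(X)
    f = np.full(y.shape, b)             # f_0(X)
    stages = []
    for _ in range(n_iter):
        neg_grad = y - f                # negative gradient of L w.r.t. f
        stump = fit_stump(X, neg_grad)  # weak learner fit to the negative gradient
        g = predict_stump(stump, X)     # g_i
        denom = g @ g
        # Exact line search: w_i = argmin_w 0.5 * sum((y - f - w * g) ** 2)
        w = (neg_grad @ g) / denom if denom > 0 else 0.0
        f = f + learning_rate * w * g   # f_i = f_{i-1} + eta * w_i * g_i
        stages.append((w, stump))
    return b, learning_rate, stages


def gbm_predict(model, X):
    b, learning_rate, stages = model
    f = np.full(X.shape[0], b)
    for w, stump in stages:
        f = f + learning_rate * w * predict_stump(stump, X)
    return f


# Usage: fit a noisy sine curve and report the training error.
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=200)
model = gbm_fit(X, y, n_iter=100, learning_rate=0.1)
print("training MSE:", np.mean((gbm_predict(model, X) - y) ** 2))
```

With squared error and leaf values set to the mean residual in each leaf, each stump's predictions are a least-squares projection of the residuals, so the line search returns w_i = 1 (up to rounding). The weights therefore matter more for losses such as cross-entropy, where the negative gradient and the optimal step no longer coincide.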

Callers (2)

plot · function · 0.90
test_gbdt · function · 0.90

Calls

no outgoing calls

Tested by (1)

test_gbdt · function · 0.72