MCPcopy
hub / github.com/ddbourgin/numpy-ml / DecisionTree

Class DecisionTree

numpy_ml/trees/dt.py:21–212  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

19
20
21class DecisionTree:
def __init__(
    self,
    classifier=True,
    max_depth=None,
    n_feats=None,
    criterion="entropy",
    seed=None,
):
    """
    A decision tree model for regression and classification problems.

    Parameters
    ----------
    classifier : bool
        Whether to treat target values as categorical (classifier =
        True) or continuous (classifier = False). Default is True.
    max_depth : int or None
        The depth at which to stop growing the tree. If None, grow the tree
        until all leaves are pure. An explicit 0 is stored as 0 (only None
        means "unbounded"). Default is None.
    n_feats : int or None
        Specifies the number of features to sample on each split. If None,
        use all features on each split. Default is None.
    criterion : {'mse', 'entropy', 'gini'}
        The error criterion to use when calculating splits. When
        `classifier` is False, valid entries are {'mse'}. When `classifier`
        is True, valid entries are {'entropy', 'gini'}. Default is
        'entropy'.
    seed : int or None
        Seed passed to NumPy's global random number generator via
        ``np.random.seed``. If None, the generator is left untouched;
        0 is a valid seed. Default is None.

    Raises
    ------
    ValueError
        If `criterion` is not one of {'mse', 'entropy', 'gini'}, or is not
        valid for the chosen value of `classifier`.
    """
    # Compare against None explicitly: a truthiness test would silently
    # skip seeding when seed=0.
    if seed is not None:
        np.random.seed(seed)

    self.depth = 0
    self.root = None

    self.n_feats = n_feats
    self.criterion = criterion
    self.classifier = classifier
    # Only None means "no depth limit"; an explicit 0 must not become np.inf.
    self.max_depth = max_depth if max_depth is not None else np.inf

    if not classifier and criterion in ["gini", "entropy"]:
        raise ValueError(
            "{} is a valid criterion only when classifier = True.".format(criterion)
        )
    if classifier and criterion == "mse":
        raise ValueError("`mse` is a valid criterion only when classifier = False.")
    # Reject anything outside the documented set up front.
    if criterion not in ["mse", "entropy", "gini"]:
        raise ValueError(
            "Unrecognized criterion: {!r}. Valid options are 'mse', "
            "'entropy', and 'gini'.".format(criterion)
        )
69
70 def fit(self, X, Y):
71 """
72 Fit a binary decision tree to a dataset.
73
74 Parameters
75 ----------
76 X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)`
77 The training data of `N` examples, each with `M` features
78 Y : :py:class:`ndarray <numpy.ndarray>` of shape `(N,)`

Callers 4

plotFunction · 0.90
test_DecisionTreeFunction · 0.90
fitMethod · 0.85
fitMethod · 0.85

Calls

no outgoing calls

Tested by 1

test_DecisionTreeFunction · 0.72