MCPcopy
hub / github.com/ddbourgin/numpy-ml / RandomForest

Class RandomForest

numpy_ml/trees/rf.py:11–99  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

9
10
11class RandomForest:
12 def __init__(
13 self, n_trees, max_depth, n_feats, classifier=True, criterion="entropy"
14 ):
15 """
16 An ensemble (forest) of decision trees where each split is calculated
17 using a random subset of the features in the input.
18
19 Parameters
20 ----------
21 n_trees : int
22 The number of individual decision trees to use within the ensemble.
23 max_depth: int or None
24 The depth at which to stop growing each decision tree. If None,
25 grow each tree until the leaf nodes are pure.
26 n_feats : int
27 The number of features to sample on each split.
28 classifier : bool
29 Whether `Y` contains class labels or real-valued targets. Default
30 is True.
31 criterion : {'entropy', 'gini', 'mse'}
32 The error criterion to use when calculating splits for each weak
33 learner. When ``classifier = False``, valid entries are {'mse'}.
34 When ``classifier = True``, valid entries are {'entropy', 'gini'}.
35 Default is 'entropy'.
36 """
37 self.trees = []
38 self.n_trees = n_trees
39 self.n_feats = n_feats
40 self.max_depth = max_depth
41 self.criterion = criterion
42 self.classifier = classifier
43
44 def fit(self, X, Y):
45 """
46 Create `n_trees`-worth of bootstrapped samples from the training data
47 and use each to fit a separate decision tree.
48 """
49 self.trees = []
50 for _ in range(self.n_trees):
51 X_samp, Y_samp = bootstrap_sample(X, Y)
52 tree = DecisionTree(
53 n_feats=self.n_feats,
54 max_depth=self.max_depth,
55 criterion=self.criterion,
56 classifier=self.classifier,
57 )
58 tree.fit(X_samp, Y_samp)
59 self.trees.append(tree)
60
61 def predict(self, X):
62 """
63 Predict the target value for each entry in `X`.
64
65 Parameters
66 ----------
67 X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)`
68 The training data of `N` examples, each with `M` features.

Callers 2

plotFunction · 0.90
test_RandomForestFunction · 0.90

Calls

no outgoing calls

Tested by 1

test_RandomForestFunction · 0.72