| 9 | |
| 10 | |
| 11 | class RandomForest: |
| 12 | def __init__( |
| 13 | self, n_trees, max_depth, n_feats, classifier=True, criterion="entropy" |
| 14 | ): |
| 15 | """ |
| 16 | An ensemble (forest) of decision trees where each split is calculated |
| 17 | using a random subset of the features in the input. |
| 18 | |
| 19 | Parameters |
| 20 | ---------- |
| 21 | n_trees : int |
| 22 | The number of individual decision trees to use within the ensemble. |
| 23 | max_depth: int or None |
| 24 | The depth at which to stop growing each decision tree. If None, |
| 25 | grow each tree until the leaf nodes are pure. |
| 26 | n_feats : int |
| 27 | The number of features to sample on each split. |
| 28 | classifier : bool |
| 29 | Whether `Y` contains class labels or real-valued targets. Default |
| 30 | is True. |
| 31 | criterion : {'entropy', 'gini', 'mse'} |
| 32 | The error criterion to use when calculating splits for each weak |
| 33 | learner. When ``classifier = False``, valid entries are {'mse'}. |
| 34 | When ``classifier = True``, valid entries are {'entropy', 'gini'}. |
| 35 | Default is 'entropy'. |
| 36 | """ |
| 37 | self.trees = [] |
| 38 | self.n_trees = n_trees |
| 39 | self.n_feats = n_feats |
| 40 | self.max_depth = max_depth |
| 41 | self.criterion = criterion |
| 42 | self.classifier = classifier |
| 43 | |
| 44 | def fit(self, X, Y): |
| 45 | """ |
| 46 | Create `n_trees`-worth of bootstrapped samples from the training data |
| 47 | and use each to fit a separate decision tree. |
| 48 | """ |
| 49 | self.trees = [] |
| 50 | for _ in range(self.n_trees): |
| 51 | X_samp, Y_samp = bootstrap_sample(X, Y) |
| 52 | tree = DecisionTree( |
| 53 | n_feats=self.n_feats, |
| 54 | max_depth=self.max_depth, |
| 55 | criterion=self.criterion, |
| 56 | classifier=self.classifier, |
| 57 | ) |
| 58 | tree.fit(X_samp, Y_samp) |
| 59 | self.trees.append(tree) |
| 60 | |
| 61 | def predict(self, X): |
| 62 | """ |
| 63 | Predict the target value for each entry in `X`. |
| 64 | |
| 65 | Parameters |
| 66 | ---------- |
| 67 | X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)` |
| 68 | The training data of `N` examples, each with `M` features. |
no outgoing calls