| 99 | |
| 100 | |
class RandomForestRegressor(RandomForest):
    """Random forest of regression trees whose per-tree predictions are averaged.

    Parameters
    ----------
    n_estimators : int
        Number of regression trees created by the constructor.
    max_features, min_samples_split, max_depth
        Forwarded unchanged to ``RandomForest.__init__``; their exact
        semantics are defined by the base class (not visible here).
    criterion : str
        Name of the split criterion. Only ``"mse"`` is supported; it maps to
        ``mse_criterion``.

    Raises
    ------
    ValueError
        If ``criterion`` is not ``"mse"``.
    """

    def __init__(
        self,
        n_estimators=10,
        max_features=None,
        min_samples_split=10,
        max_depth=None,
        criterion="mse",
    ):
        super(RandomForestRegressor, self).__init__(
            n_estimators=n_estimators,
            max_features=max_features,
            min_samples_split=min_samples_split,
            max_depth=max_depth,
        )

        if criterion == "mse":
            self.criterion = mse_criterion
        else:
            # Name the bad value so the caller can see what was rejected.
            raise ValueError(
                "Unsupported criterion {!r}; RandomForestRegressor only "
                "supports 'mse'.".format(criterion)
            )

        # Create one untrained regression tree per estimator, all sharing the
        # same split criterion. NOTE(review): assumes the base class has already
        # set `self.trees` to a list and handles fitting elsewhere — confirm.
        for _ in range(self.n_estimators):
            self.trees.append(Tree(regression=True, criterion=self.criterion))

    def _predict(self, X=None):
        """Return the mean of every tree's prediction for each row of ``X``.

        ``X`` must be an array with a ``shape`` attribute (rows along axis 0).
        The ``None`` default exists only for signature compatibility; passing
        ``None`` fails. Each ``tree.predict(X)`` is expected to return one value
        per row of ``X``.
        """
        # Size the buffer from the trees actually held rather than from
        # n_estimators, so the columns always match the loop below.
        predictions = np.zeros((X.shape[0], len(self.trees)))
        for i, tree in enumerate(self.trees):
            predictions[:, i] = tree.predict(X)
        return predictions.mean(axis=1)