| 19 | |
| 20 | |
| 21 | class DecisionTree: |
def __init__(
    self,
    classifier=True,
    max_depth=None,
    n_feats=None,
    criterion="entropy",
    seed=None,
):
    """
    A decision tree model for regression and classification problems.

    Parameters
    ----------
    classifier : bool
        Whether to treat target values as categorical (classifier =
        True) or continuous (classifier = False). Default is True.
    max_depth : int or None
        The depth at which to stop growing the tree. If None, grow the tree
        until all leaves are pure. Only None means "unbounded"; an explicit
        0 is kept as 0. Default is None.
    n_feats : int or None
        Specifies the number of features to sample on each split. If None,
        use all features on each split. Default is None.
    criterion : {'mse', 'entropy', 'gini'}
        The error criterion to use when calculating splits. When
        `classifier` is False, valid entries are {'mse'}. When `classifier`
        is True, valid entries are {'entropy', 'gini'}. Default is
        'entropy'.
    seed : int or None
        Seed passed to NumPy's global random number generator
        (``np.random.seed``). If None, the global generator is left
        untouched; ``seed=0`` is honored like any other seed.
        Default is None.

    Raises
    ------
    ValueError
        If `criterion` is 'gini' or 'entropy' while `classifier` is False,
        or 'mse' while `classifier` is True.
    """
    # Validate the configuration first so that a rejected construction does
    # not reseed NumPy's global RNG as a side effect.
    if not classifier and criterion in ["gini", "entropy"]:
        raise ValueError(
            "{} is a valid criterion only when classifier = True.".format(criterion)
        )
    if classifier and criterion == "mse":
        raise ValueError("`mse` is a valid criterion only when classifier = False.")

    # Compare against None instead of relying on truthiness: 0 is a valid seed.
    if seed is not None:
        np.random.seed(seed)

    self.depth = 0
    self.root = None

    self.n_feats = n_feats
    self.criterion = criterion
    self.classifier = classifier
    # Only None means "no depth limit"; an explicit max_depth=0 must not
    # silently become unbounded.
    self.max_depth = np.inf if max_depth is None else max_depth
| 70 | def fit(self, X, Y): |
| 71 | """ |
| 72 | Fit a binary decision tree to a dataset. |
| 73 | |
| 74 | Parameters |
| 75 | ---------- |
| 76 | X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)` |
| 77 | The training data of `N` examples, each with `M` features |
| 78 | Y : :py:class:`ndarray <numpy.ndarray>` of shape `(N,)` |
no outgoing calls