Compute the impurity gain associated with a given split. IG(split) = loss(parent) - weighted_avg[loss(left_child), loss(right_child)]
(self, Y, split_thresh, feat_values)
| 171 | return split_idx, split_thresh |
| 172 | |
| 173 | def _impurity_gain(self, Y, split_thresh, feat_values): |
| 174 | """ |
| 175 | Compute the impurity gain associated with a given split. |
| 176 | |
| 177 | IG(split) = loss(parent) - weighted_avg[loss(left_child), loss(right_child)] |
| 178 | """ |
| 179 | if self.criterion == "entropy": |
| 180 | loss = entropy |
| 181 | elif self.criterion == "gini": |
| 182 | loss = gini |
| 183 | elif self.criterion == "mse": |
| 184 | loss = mse |
| 185 | |
| 186 | parent_loss = loss(Y) |
| 187 | |
| 188 | # generate split |
| 189 | left = np.argwhere(feat_values <= split_thresh).flatten() |
| 190 | right = np.argwhere(feat_values > split_thresh).flatten() |
| 191 | |
| 192 | if len(left) == 0 or len(right) == 0: |
| 193 | return 0 |
| 194 | |
| 195 | # compute the weighted avg. of the loss for the children |
| 196 | n = len(Y) |
| 197 | n_l, n_r = len(left), len(right) |
| 198 | e_l, e_r = loss(Y[left]), loss(Y[right]) |
| 199 | child_loss = (n_l / n) * e_l + (n_r / n) * e_r |
| 200 | |
| 201 | # impurity gain is difference in loss before vs. after split |
| 202 | ig = parent_loss - child_loss |
| 203 | return ig |
| 204 | |
| 205 | def _traverse(self, X, node, prob=False): |
| 206 | if isinstance(node, Leaf): |