MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_RandomForest

Function test_RandomForest

numpy_ml/tests/test_trees.py:154–252  ·  view source on GitHub ↗
(N=1)

Source from the content-addressed store, hash-verified

152
153
154def test_RandomForest(N=1):
155 np.random.seed(12345)
156 i = 1
157 while i <= N:
158 n_ex = np.random.randint(2, 100)
159 n_feats = np.random.randint(2, 100)
160 n_trees = np.random.randint(2, 100)
161 max_depth = np.random.randint(1, 5)
162
163 classifier = np.random.choice([True, False])
164 if classifier:
165 # create classification problem
166 n_classes = np.random.randint(2, 10)
167 X, Y = make_blobs(
168 n_samples=n_ex, centers=n_classes, n_features=n_feats, random_state=i
169 )
170 X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3, random_state=i)
171
172 # initialize model
173 def loss(yp, y):
174 return 1 - accuracy_score(yp, y)
175
176 # initialize model
177 criterion = np.random.choice(["entropy", "gini"])
178 mine = RandomForest(
179 classifier=classifier,
180 n_feats=n_feats,
181 n_trees=n_trees,
182 criterion=criterion,
183 max_depth=max_depth,
184 )
185 gold = RandomForestClassifier(
186 n_estimators=n_trees,
187 max_features=n_feats,
188 criterion=criterion,
189 max_depth=max_depth,
190 bootstrap=True,
191 )
192 else:
193 # create regeression problem
194 X, Y = make_regression(n_samples=n_ex, n_features=n_feats, random_state=i)
195 X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3, random_state=i)
196
197 # initialize model
198 criterion = "mse"
199 loss = mean_squared_error
200 mine = RandomForest(
201 criterion=criterion,
202 n_feats=n_feats,
203 n_trees=n_trees,
204 max_depth=max_depth,
205 classifier=classifier,
206 )
207 gold = RandomForestRegressor(
208 n_estimators=n_trees,
209 max_features=n_feats,
210 criterion=criterion,
211 max_depth=max_depth,

Callers

nothing calls this directly

Calls 6

fit (method) · 0.95
predict (method) · 0.95
RandomForest (class) · 0.90
loss (function) · 0.70
fit (method) · 0.45
predict (method) · 0.45

Tested by

no test coverage detected