MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_DecisionTree

Function test_DecisionTree

numpy_ml/tests/test_trees.py:70–151  ·  view source on GitHub ↗
(N=1)

Source from the content-addressed store, hash-verified

68
69
70def test_DecisionTree(N=1):
71 i = 1
72 np.random.seed(12345)
73 while i <= N:
74 n_ex = np.random.randint(2, 100)
75 n_feats = np.random.randint(2, 100)
76 max_depth = np.random.randint(1, 5)
77
78 classifier = np.random.choice([True, False])
79 if classifier:
80 # create classification problem
81 n_classes = np.random.randint(2, 10)
82 X, Y = make_blobs(
83 n_samples=n_ex, centers=n_classes, n_features=n_feats, random_state=i
84 )
85 X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3, random_state=i)
86
87 # initialize model
88 def loss(yp, y):
89 return 1 - accuracy_score(yp, y)
90
91 criterion = np.random.choice(["entropy", "gini"])
92 mine = DecisionTree(
93 classifier=classifier, max_depth=max_depth, criterion=criterion
94 )
95 gold = DecisionTreeClassifier(
96 criterion=criterion,
97 max_depth=max_depth,
98 splitter="best",
99 random_state=i,
100 )
101 else:
102 # create regeression problem
103 X, Y = make_regression(n_samples=n_ex, n_features=n_feats, random_state=i)
104 X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3, random_state=i)
105
106 # initialize model
107 criterion = "mse"
108 loss = mean_squared_error
109 mine = DecisionTree(
110 criterion=criterion, max_depth=max_depth, classifier=classifier
111 )
112 gold = DecisionTreeRegressor(
113 criterion=criterion, max_depth=max_depth, splitter="best"
114 )
115
116 print("Trial {}".format(i))
117 print("\tClassifier={}, criterion={}".format(classifier, criterion))
118 print("\tmax_depth={}, n_feats={}, n_ex={}".format(max_depth, n_feats, n_ex))
119 if classifier:
120 print("\tn_classes: {}".format(n_classes))
121
122 # fit 'em
123 mine.fit(X, Y)
124 gold.fit(X, Y)
125
126 # get preds on training set
127 y_pred_mine = mine.predict(X)

Callers

nothing calls this directly

Calls 4

fit · method · 0.95
predict · method · 0.95
DecisionTree · class · 0.90
loss · function · 0.70

Tested by

no test coverage detected