MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_GaussianNB

Function test_GaussianNB

numpy_ml/tests/test_naive_bayes.py:12–74  ·  view source on GitHub ↗
(N=10)

Source from the content-addressed store, hash-verified

10
11
12def test_GaussianNB(N=10):
13 np.random.seed(12345)
14 N = np.inf if N is None else N
15
16 i = 1
17 eps = np.finfo(float).eps
18 while i < N + 1:
19 n_ex = np.random.randint(1, 300)
20 n_feats = np.random.randint(1, 100)
21 n_classes = np.random.randint(2, 10)
22
23 X = random_tensor((n_ex, n_feats), standardize=True)
24 y = np.random.randint(0, n_classes, size=n_ex)
25
26 X_test = random_tensor((n_ex, n_feats), standardize=True)
27
28 NB = GaussianNBClassifier(eps=1e-09)
29 NB.fit(X, y)
30
31 preds = NB.predict(X_test)
32
33 sklearn_NB = naive_bayes.GaussianNB()
34 sklearn_NB.fit(X, y)
35
36 sk_preds = sklearn_NB.predict(X_test)
37
38 for j in range(len(NB.labels)):
39 P = NB.parameters
40 jointi = np.log(sklearn_NB.class_prior_[j])
41 jointi_mine = np.log(P["prior"][j])
42
43 np.testing.assert_almost_equal(jointi, jointi_mine)
44
45 n_jk = -0.5 * np.sum(np.log(2.0 * np.pi * sklearn_NB.sigma_[j, :] + eps))
46 n_jk_mine = -0.5 * np.sum(np.log(2.0 * np.pi * P["sigma"][j] + eps))
47
48 np.testing.assert_almost_equal(n_jk_mine, n_jk)
49
50 n_jk2 = n_jk - 0.5 * np.sum(
51 ((X_test - sklearn_NB.theta_[j, :]) ** 2) / (sklearn_NB.sigma_[j, :]), 1
52 )
53
54 n_jk2_mine = n_jk_mine - 0.5 * np.sum(
55 ((X_test - P["mean"][j]) ** 2) / (P["sigma"][j]), 1
56 )
57 np.testing.assert_almost_equal(n_jk2_mine, n_jk2, decimal=4)
58
59 llh = jointi + n_jk2
60 llh_mine = jointi_mine + n_jk2_mine
61
62 np.testing.assert_almost_equal(llh_mine, llh, decimal=4)
63
64 np.testing.assert_almost_equal(P["prior"], sklearn_NB.class_prior_)
65 np.testing.assert_almost_equal(P["mean"], sklearn_NB.theta_)
66 np.testing.assert_almost_equal(P["sigma"], sklearn_NB.sigma_)
67 np.testing.assert_almost_equal(
68 sklearn_NB._joint_log_likelihood(X_test),
69 NB._log_posterior(X_test),

Callers

nothing calls this directly

Calls 7

fit (method) · 0.95
predict (method) · 0.95
_log_posterior (method) · 0.95
random_tensor (function) · 0.90
fit (method) · 0.45
predict (method) · 0.45

Tested by

no test coverage detected