MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_glm

Function test_glm

numpy_ml/tests/test_glm.py:10–66  ·  view source on GitHub ↗
(N=20)

Source from the content-addressed store, hash-verified

8
9
def test_glm(N=20):
    """Cross-check ``GeneralizedLinearModel`` against statsmodels' ``GLM``.

    Runs randomized trials (forever when ``N`` is None). Each trial:

    * draws a standardized design matrix whose feature count is at most half
      the sample count (to avoid perfect separation / non-unique solutions);
    * draws a target suited to a randomly chosen link:

      - binary 0/1 for ``"logit"``;
      - integers in [0, 100) for ``"log"``;
      - a standardized continuous tensor for ``"identity"``;

    * fits statsmodels' GLM with the matching family as the reference;
    * fits the numpy-ml model and asserts both models agree. Coefficients must
      match to 6 decimals and predictions to 5 decimals.

    Parameters
    ----------
    N : int or None
        Number of trials; ``None`` means loop indefinitely.
    """
    np.random.seed(12345)
    n_trials = np.inf if N is None else N
    target_dim = 1

    # statsmodels family whose default link corresponds to each link name.
    # A fresh family object is built per trial.
    family_factory = {
        "identity": sm.families.Gaussian,
        "logit": sm.families.Binomial,
        "log": sm.families.Poisson,
    }

    def draw_target(link, n_rows):
        # Target distribution must be valid for the chosen link/family.
        shape = (n_rows, target_dim)
        if link == "logit":
            return np.random.choice([0.0, 1.0], size=shape)
        if link == "log":
            return np.random.choice(np.arange(0, 100), size=shape)
        if link == "identity":
            return random_tensor(shape, standardize=True)
        raise ValueError(f"Unknown link function {link}")

    trial = 1
    while trial < n_trials + 1:
        # NOTE: the order of the np.random draws below matters -- the test is
        # seeded, so reordering them would change every generated dataset.
        n_samples = np.random.randint(10, 100)
        n_feats = np.random.randint(1, 1 + n_samples // 2)
        fit_intercept = np.random.choice([True, False])
        link_name = np.random.choice(list(_GLM_LINKS.keys()))

        print(f"Link: {link_name}")
        print(f"Fit intercept: {fit_intercept}")

        X = random_tensor((n_samples, n_feats), standardize=True)
        y = draw_target(link_name, n_samples)

        # statsmodels has no intercept option, so prepend a column of ones
        # to the design matrix when the model should fit one.
        if fit_intercept:
            design = np.c_[np.ones(X.shape[0]), X]
        else:
            design = X

        # Reference ("gold standard") fit on the full dataset.
        reference = sm.GLM(
            y, design, family=family_factory[link_name]()
        ).fit()

        candidate = GeneralizedLinearModel(
            link=link_name, fit_intercept=fit_intercept
        )
        candidate.fit(X, y)

        np.testing.assert_almost_equal(
            candidate.beta.T.ravel(), reference.params, decimal=6
        )
        print("\t1. Overall model coefficients match")

        np.testing.assert_almost_equal(
            candidate.predict(X), reference.predict(design), decimal=5
        )
        print("\t2. Overall model predictions match")

        print("\tPASSED\n")
        trial += 1

Callers

nothing calls this directly

Calls 6

fit (method) · 0.95
predict (method) · 0.95
random_tensor (function) · 0.90
fit (method) · 0.45
predict (method) · 0.45

Tested by

no test coverage detected