(N=20)
| 8 | |
| 9 | |
def test_glm(N=20):
    """Compare ``GeneralizedLinearModel`` against statsmodels' ``sm.GLM``.

    Each trial draws a random problem size, a random link name from
    ``_GLM_LINKS`` and a random ``fit_intercept`` flag. It generates a
    synthetic response suited to that link and fits both models on the
    full dataset. It then asserts that the fitted coefficients agree to 6
    decimals and the in-sample predictions agree to 5 decimals.

    Parameters
    ----------
    N : int or None
        Number of random trials to run. ``None`` means no upper bound:
        trials repeat until an assertion fails or the run is interrupted.

    Raises
    ------
    ValueError
        If ``_GLM_LINKS`` yields a link name this test has no setup for.
    AssertionError
        If coefficients or predictions disagree with statsmodels.
    """
    # statsmodels family whose default link matches each of our link names.
    # Store classes rather than instances: the table is built once, but each
    # trial still gets a fresh family object, as the original per-trial
    # dict did.
    family_for_link = {
        "identity": sm.families.Gaussian,
        "logit": sm.families.Binomial,
        "log": sm.families.Poisson,
    }
    target_dim = 1  # every trial uses a single response column

    # Seed the *global* RNG so the sequence of random problems is
    # reproducible. Note this also affects any later code that uses
    # np.random.
    np.random.seed(12345)
    N = np.inf if N is None else N

    i = 1
    while i < N + 1:
        n_samples = np.random.randint(10, 100)

        # n_feats << n_samples to avoid perfect separation / multiple solutions
        n_feats = np.random.randint(1, 1 + n_samples // 2)

        fit_intercept = np.random.choice([True, False])
        _link = np.random.choice(list(_GLM_LINKS.keys()))

        print(f"Link: {_link}")
        print(f"Fit intercept: {fit_intercept}")

        X = random_tensor((n_samples, n_feats), standardize=True)

        # Draw a response compatible with each link's family:
        # binary labels for logit, non-negative counts for log, and
        # real values for identity.
        if _link == "logit":
            y = np.random.choice([0.0, 1.0], size=(n_samples, target_dim))
        elif _link == "log":
            y = np.random.choice(np.arange(0, 100), size=(n_samples, target_dim))
        elif _link == "identity":
            y = random_tensor((n_samples, target_dim), standardize=True)
        else:
            raise ValueError(f"Unknown link function {_link}")

        # Fit gold standard model on the entire dataset. sm.GLM does not add
        # an intercept itself, so prepend a column of ones when one is wanted.
        fam = family_for_link[_link]()
        Xdesign = np.c_[np.ones(X.shape[0]), X] if fit_intercept else X

        glm_gold = sm.GLM(y, Xdesign, family=fam)
        glm_gold = glm_gold.fit()

        glm_mine = GeneralizedLinearModel(link=_link, fit_intercept=fit_intercept)
        glm_mine.fit(X, y)

        # check that model coefficients match; flatten ours to compare with
        # the gold model's parameter vector (intercept first when present)
        beta = glm_mine.beta.T.ravel()
        np.testing.assert_almost_equal(beta, glm_gold.params, decimal=6)
        print("\t1. Overall model coefficients match")

        # check that model predictions match.
        # NOTE(review): assert_almost_equal rejects shape mismatches, so this
        # assumes glm_mine.predict returns the same shape as statsmodels'
        # predictions -- confirm against GeneralizedLinearModel.predict.
        np.testing.assert_almost_equal(
            glm_mine.predict(X), glm_gold.predict(Xdesign), decimal=5
        )
        print("\t2. Overall model predictions match")

        print("\tPASSED\n")
        i += 1
nothing calls this directly
no test coverage detected