hub / github.com/ddbourgin/numpy-ml / fit

Method fit

numpy_ml/linear_models/glm.py:129–185

Find the maximum likelihood GLM coefficients via IRLS.

Parameters
X : ndarray of shape (N, M). A dataset consisting of N examples, each of dimension M.
y : ndarray of shape (N,). The targets for each of the N examples in X.

(self, X, y)

Source from the content-addressed store, hash-verified

        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        """
        Find the maximum likelihood GLM coefficients via IRLS.

        Parameters
        ----------
        X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)`
            A dataset consisting of `N` examples, each of dimension `M`.
        y : :py:class:`ndarray <numpy.ndarray>` of shape `(N,)`
            The targets for each of the `N` examples in `X`.

        Returns
        -------
        self : :class:`GeneralizedLinearModel <numpy_ml.linear_models.GeneralizedLinearModel>` instance
        """  # noqa: E501
        y = np.squeeze(y)
        assert y.ndim == 1

        N, M = X.shape
        L = _GLM_LINKS[self.link]

        # starting values for parameters
        mu = np.ones_like(y) * np.mean(y)
        eta = L["link"](mu)
        theta = L["theta"](mu)

        # convert X to a design matrix if we're fitting an intercept
        if self.fit_intercept:
            X = np.c_[np.ones(N), X]

        # IRLS for GLM
        i = 0
        diff, beta = np.inf, np.inf
        while diff > (self.tol * M):
            if i > self.max_iter:
                print("Warning: Model did not converge")
                break

            # compute first-order Taylor approx.
            z = eta + (y - mu) * L["link_prime"](mu)
            w = L["p"] / (L["b_prime2"](theta) * L["link_prime"](mu) ** 2)

            # perform weighted least-squares on z
            wlr = LinearRegression(fit_intercept=False)
            beta_new = wlr.fit(X, z, weights=w).beta.ravel()

            eta = X @ beta_new
            mu = L["inv_link"](eta)
            theta = L["theta"](mu)

            diff = np.linalg.norm(beta - beta_new, ord=1)
            beta = beta_new
            i += 1

        self.beta = beta
        self._is_fit = True
        return self
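To make the IRLS loop above concrete, here is a minimal standalone sketch for one specific case: Poisson regression with a log link. It is not the library's code. The function name `irls_poisson` and its defaults are hypothetical, the weighted least-squares step is solved through the normal equations rather than through `LinearRegression`, and the stopping rule is a plain tolerance on the L1 change in the coefficients. For this link, g'(mu) = 1 / mu, b''(theta) = mu and p = 1, so the weight formula in the source reduces to w = mu.

```python
import numpy as np


def irls_poisson(X, y, tol=1e-8, max_iter=100):
    """IRLS for Poisson regression with a log link (illustrative sketch)."""
    y = np.squeeze(y).astype(float)
    N = X.shape[0]
    X = np.c_[np.ones(N), X]  # prepend an intercept column

    # start from the constant model mu = mean(y)
    mu = np.full(N, y.mean())
    eta = np.log(mu)

    beta = np.full(X.shape[1], np.inf)
    for _ in range(max_iter):
        # working response and weights for the log link:
        # g'(mu) = 1 / mu and b''(theta) = mu, hence w = mu
        z = eta + (y - mu) / mu
        w = mu

        # weighted least squares on z via the normal equations
        XtW = X.T * w
        beta_new = np.linalg.solve(XtW @ X, XtW @ z)

        eta = X @ beta_new
        mu = np.exp(eta)

        # L1 change in the coefficients, as in the source loop
        diff = np.linalg.norm(beta - beta_new, ord=1)
        beta = beta_new
        if diff < tol:
            break
    return beta
```

At convergence the returned coefficients should satisfy the Poisson score equation, X^T (y - mu) = 0, which is a convenient sanity check for a sketch like this.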

Callers 1

test_glm · Function · 0.95

Calls 2

fit · Method · 0.95
LinearRegression · Class · 0.90
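The `LinearRegression` call listed here serves as the weighted least-squares solver inside each IRLS iteration. The sketch below shows one common way to solve that subproblem with plain NumPy. It assumes the weights multiply the squared residuals, and the helper name `weighted_least_squares` is hypothetical; how the library's own class handles `weights` internally is not shown on this page.

```python
import numpy as np


def weighted_least_squares(X, z, w):
    """Minimize sum_i w_i * (z_i - x_i @ beta)**2 (illustrative sketch)."""
    sqrt_w = np.sqrt(w)
    # scaling each row by sqrt(w_i) turns the weighted problem
    # into an ordinary least-squares problem
    beta, *_ = np.linalg.lstsq(X * sqrt_w[:, None], z * sqrt_w, rcond=None)
    return beta
```

Solving through `np.linalg.lstsq` avoids forming X^T W X explicitly, which tends to be better conditioned than the normal equations when columns of `X` are nearly collinear.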

Tested by 1

test_glm · Function · 0.76