Repository: github.com/scikit-learn/scikit-learn

Function _rescale_data

sklearn/linear_model/_base.py:223–281

Rescales data sample-wise by the square root of sample_weight. For many linear models this makes sample_weight easy to support: with S = diag(sample_weight), the weighted residual (y - X w)' S (y - X w) equals ||y_rescaled - X_rescaled w||_2^2 when y_rescaled = sqrt(S) y and X_rescaled = sqrt(S) X, because sqrt(S)' sqrt(S) = S for a diagonal matrix of non-negative weights.
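The identity in the summary can be checked numerically. The following is a minimal NumPy sketch, assuming dense arrays and strictly positive weights; both helper names (`weighted_lstsq_via_rescaling`, `weighted_lstsq_normal_equations`) are hypothetical and not part of scikit-learn. It solves the weighted normal equations X' S X w = X' S y directly and compares the result with ordinary least squares on the rescaled data.

```python
import numpy as np


def weighted_lstsq_via_rescaling(X, y, sample_weight):
    """Hypothetical helper: weighted least squares by scaling rows by sqrt(weight)."""
    sw_sqrt = np.sqrt(sample_weight)
    X_rescaled = X * sw_sqrt[:, None]  # scale each row i of X by sqrt(s_i)
    y_rescaled = y * sw_sqrt           # scale each target y_i by sqrt(s_i)
    # Ordinary least squares on the rescaled problem minimizes
    # ||y_rescaled - X_rescaled w||_2^2, which equals (y - X w)' S (y - X w).
    coef, *_ = np.linalg.lstsq(X_rescaled, y_rescaled, rcond=None)
    return coef


def weighted_lstsq_normal_equations(X, y, sample_weight):
    """Reference solution from the weighted normal equations X' S X w = X' S y."""
    XtS = X.T * sample_weight  # same as X.T @ np.diag(sample_weight)
    return np.linalg.solve(XtS @ X, XtS @ y)


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=50)
sample_weight = rng.uniform(0.1, 5.0, size=50)

w_rescaled = weighted_lstsq_via_rescaling(X, y, sample_weight)
w_reference = weighted_lstsq_normal_equations(X, y, sample_weight)
print(np.allclose(w_rescaled, w_reference))
```

Rescaling lets any unweighted least-squares solver handle weights without being modified, which is the motivation the summary gives.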

(X, y, sample_weight, inplace=False)
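The `inplace` flag decides whether dense inputs are multiplied with `*=` (mutating the caller's arrays) or with `*` (allocating new arrays). The sketch below mimics only the dense branch of the source listing, using a hypothetical stand-in `rescale_dense` rather than the private scikit-learn function, to show the observable difference.

```python
import numpy as np


def rescale_dense(X, y, sample_weight, inplace=False):
    """Hypothetical stand-in for the dense-only branch of _rescale_data."""
    sw_sqrt = np.sqrt(sample_weight)
    # A 2-D y needs a column vector so each row (sample) is scaled; a 1-D y does not.
    y_factor = sw_sqrt if y.ndim == 1 else sw_sqrt[:, None]
    if inplace:
        X *= sw_sqrt[:, None]  # mutates the caller's array
        y *= y_factor
    else:
        X = X * sw_sqrt[:, None]  # allocates a new array
        y = y * y_factor
    return X, y, sw_sqrt


X = np.ones((3, 2))
y = np.ones(3)
sw = np.array([1.0, 4.0, 9.0])

X_new, y_new, _ = rescale_dense(X, y, sw)
print(X[1, 0], X_new[1, 0])  # 1.0 2.0: the original X is untouched

X_out, y_out, _ = rescale_dense(X, y, sw, inplace=True)
print(X_out is X, y[2])  # True 3.0: the caller's arrays were modified
```

NumPy's in-place `*=` cannot store a float product back into an integer array, so the in-place branch presumably relies on the caller having already validated the inputs to a floating-point dtype.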


```python
# Excerpt from sklearn/linear_model/_base.py; `sp` and `sparse` (scipy.sparse),
# `get_namespace` and `safe_sparse_dot` are module-level imports of that file.
def _rescale_data(X, y, sample_weight, inplace=False):
    """Rescale data sample-wise by square root of sample_weight.

    For many linear models, this enables easy support for sample_weight because

        (y - X w)' S (y - X w)

    with S = diag(sample_weight) becomes

        ||y_rescaled - X_rescaled w||_2^2

    when setting

        y_rescaled = sqrt(S) y
        X_rescaled = sqrt(S) X

    The parameter `inplace` only takes effect for dense X and dense y.

    Returns
    -------
    X_rescaled : {array-like, sparse matrix}

    y_rescaled : {array-like, sparse matrix}

    sample_weight_sqrt : array-like of shape (n_samples,)
    """
    # Assume that _validate_data and _check_sample_weight have been called by
    # the caller.
    xp, _ = get_namespace(X, y, sample_weight)
    n_samples = X.shape[0]
    sample_weight_sqrt = xp.sqrt(sample_weight)

    if sp.issparse(X) or sp.issparse(y):
        sw_matrix = sparse.dia_array(
            (sample_weight_sqrt, 0), shape=(n_samples, n_samples)
        )

    if sp.issparse(X):
        X = safe_sparse_dot(sw_matrix, X)
    else:
        if inplace:
            X *= sample_weight_sqrt[:, None]
        else:
            X = X * sample_weight_sqrt[:, None]

    if sp.issparse(y):
        y = safe_sparse_dot(sw_matrix, y)
    else:
        if inplace:
            if y.ndim == 1:
                y *= sample_weight_sqrt
            else:
                y *= sample_weight_sqrt[:, None]
        else:
            if y.ndim == 1:
                y = y * sample_weight_sqrt
            else:
                y = y * sample_weight_sqrt[:, None]
    return X, y, sample_weight_sqrt
```
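For sparse inputs the source does not broadcast; it builds a diagonal matrix from `sample_weight_sqrt` and left-multiplies with `safe_sparse_dot`, so X stays sparse. The sketch below illustrates why the two paths agree, assuming SciPy is available; `scipy.sparse.diags` and the `@` operator stand in for scikit-learn's `sparse.dia_array` construction and `safe_sparse_dot`.

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
n_samples, n_features = 6, 4
X_dense = rng.normal(size=(n_samples, n_features))
X_dense[X_dense < 0.5] = 0.0  # zero out most entries so the sparse format is meaningful
X_sparse = sp.csr_matrix(X_dense)
sample_weight = rng.uniform(0.5, 2.0, size=n_samples)
sw_sqrt = np.sqrt(sample_weight)

# Sparse path: diag(sqrt(s)) @ X scales row i by sqrt(s_i) and keeps X sparse.
sw_matrix = sp.diags(sw_sqrt)
X_rescaled_sparse = sw_matrix @ X_sparse

# Dense path: broadcasting a column vector over the rows does the same scaling.
X_rescaled_dense = X_dense * sw_sqrt[:, None]

print(sp.issparse(X_rescaled_sparse))
print(np.allclose(X_rescaled_sparse.toarray(), X_rescaled_dense))
```

The diagonal-matrix product touches only the stored nonzeros, which is why it is used for sparse X and y instead of densifying them.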

Callers (3)

_ridge_regression (function) · 0.90
test_rescale_data (function) · 0.90
_preprocess_data (function) · 0.85

Calls (3)

get_namespace (function) · 0.90
safe_sparse_dot (function) · 0.90
_align_api_if_sparse (function) · 0.90

Tested by (1)

test_rescale_data (function) · 0.72
