Rescale data sample-wise by square root of sample_weight. For many linear models, this enables easy support for sample_weight because (y - X w)' S (y - X w) with S = diag(sample_weight) becomes ||y_rescaled - X_rescaled w||_2^2 when setting y_rescaled = sqrt(S) y and X_rescaled = sqrt(S) X.
(X, y, sample_weight, inplace=False)
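As a quick sanity check of the identity in the summary above, the following sketch evaluates both sides for a tiny hand-made data set using only the standard library. The helper names (`predict`, `weighted_sse`, `rescaled_sse`) and all numbers are illustrative assumptions, not part of scikit-learn.

```python
import math

# Toy data: 3 samples, 2 features. All values are made up for illustration.
X = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
y = [1.0, 0.0, 2.0]
sample_weight = [4.0, 1.0, 0.25]
w = [0.3, -0.7]  # an arbitrary coefficient vector


def predict(row, coef):
    """Dot product of one sample with the coefficient vector."""
    return sum(x_j * c_j for x_j, c_j in zip(row, coef))


def weighted_sse(X, y, sample_weight, coef):
    """(y - X w)' S (y - X w) with S = diag(sample_weight)."""
    return sum(
        s_i * (y_i - predict(row, coef)) ** 2
        for row, y_i, s_i in zip(X, y, sample_weight)
    )


def rescaled_sse(X, y, sample_weight, coef):
    """||sqrt(S) y - sqrt(S) X w||_2^2, computed on explicitly rescaled data."""
    sw_sqrt = [math.sqrt(s_i) for s_i in sample_weight]
    X_rescaled = [[r * x_j for x_j in row] for row, r in zip(X, sw_sqrt)]
    y_rescaled = [r * y_i for y_i, r in zip(y, sw_sqrt)]
    return sum(
        (y_i - predict(row, coef)) ** 2
        for row, y_i in zip(X_rescaled, y_rescaled)
    )


lhs = weighted_sse(X, y, sample_weight, w)
rhs = rescaled_sse(X, y, sample_weight, w)
print(math.isclose(lhs, rhs))
```

Because each residual is multiplied by the square root of its weight before squaring, the weight reappears as a plain factor in the sum, which is why the two quantities agree for any choice of `w`.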
```python
def _rescale_data(X, y, sample_weight, inplace=False):
    """Rescale data sample-wise by square root of sample_weight.

    For many linear models, this enables easy support for sample_weight because

        (y - X w)' S (y - X w)

    with S = diag(sample_weight) becomes

        ||y_rescaled - X_rescaled w||_2^2

    when setting

        y_rescaled = sqrt(S) y
        X_rescaled = sqrt(S) X

    The parameter `inplace` only takes effect for dense X and dense y.

    Returns
    -------
    X_rescaled : {array-like, sparse matrix}

    y_rescaled : {array-like, sparse matrix}

    sample_weight_sqrt : array-like of shape (n_samples,)
    """
    # Assume that _validate_data and _check_sample_weight have been called by
    # the caller.
    xp, _ = get_namespace(X, y, sample_weight)
    n_samples = X.shape[0]
    sample_weight_sqrt = xp.sqrt(sample_weight)

    # A sparse diagonal matrix is only needed when X or y is sparse.
    if sp.issparse(X) or sp.issparse(y):
        sw_matrix = sparse.dia_array(
            (sample_weight_sqrt, 0), shape=(n_samples, n_samples)
        )

    if sp.issparse(X):
        X = safe_sparse_dot(sw_matrix, X)
    else:
        if inplace:
            X *= sample_weight_sqrt[:, None]
        else:
            X = X * sample_weight_sqrt[:, None]

    if sp.issparse(y):
        y = safe_sparse_dot(sw_matrix, y)
    else:
        if inplace:
            if y.ndim == 1:
                y *= sample_weight_sqrt
            else:
                y *= sample_weight_sqrt[:, None]
        else:
            if y.ndim == 1:
                y = y * sample_weight_sqrt
            else:
                y = y * sample_weight_sqrt[:, None]

    # Matches the three values documented under "Returns" above.
    return X, y, sample_weight_sqrt
```
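To illustrate why this rescaling is useful, the sketch below solves a weighted least-squares problem two ways with NumPy: once from the weighted normal equations and once by running an ordinary least-squares solver on the rescaled data. The function `rescale_dense` is a hypothetical dense-only stand-in written for this example; it mirrors the broadcasting in the listing above but is not the scikit-learn function, and the random data is purely illustrative.

```python
import numpy as np


def rescale_dense(X, y, sample_weight):
    """Hypothetical dense-only helper: scale rows by sqrt(sample_weight)."""
    sw_sqrt = np.sqrt(sample_weight)
    X_rescaled = X * sw_sqrt[:, None]
    # y may be 1-D (n_samples,) or 2-D (n_samples, n_targets).
    y_rescaled = y * sw_sqrt if y.ndim == 1 else y * sw_sqrt[:, None]
    return X_rescaled, y_rescaled, sw_sqrt


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
sample_weight = rng.uniform(0.5, 2.0, size=20)

# Reference solution: weighted normal equations (X' S X) w = X' S y.
S = np.diag(sample_weight)
w_normal = np.linalg.solve(X.T @ S @ X, X.T @ S @ y)

# Same problem posed as ordinary least squares on the rescaled data.
X_r, y_r, _ = rescale_dense(X, y, sample_weight)
w_rescaled, *_ = np.linalg.lstsq(X_r, y_r, rcond=None)

print(np.allclose(w_normal, w_rescaled))
```

The practical consequence is that a solver with no notion of sample weights can still fit a weighted model, provided the caller rescales `X` and `y` first.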