Test get_coef on multiclass problems.
(n_features, n_targets)
@pytest.mark.parametrize("n_features", [1, 5])
@pytest.mark.parametrize("n_targets", [1, 3])
def test_get_coef_multiclass(n_features, n_targets):
    """Exercise LinearModel filters/patterns and get_coef with several targets.

    Runs three scenarios on data from ``_make_data``: a bare
    ``LinearRegression``-backed ``LinearModel``, a one-step ``Ridge`` pipeline,
    and a ``Scaler`` -> ``Vectorizer`` -> ``LinearModel`` pipeline fed with the
    data reshaped to carry a trailing singleton axis.

    Parameters
    ----------
    n_features : int
        Number of features requested from ``_make_data``.
    n_targets : int
        Number of targets requested from ``_make_data``.
    """
    # Comparisons against ``A`` only run when both dimensions exceed one.
    multi_dim = n_features > 1 and n_targets > 1

    # --- Plain linear regression: patterns with more than one regressor ---
    X, Y, A = _make_data(
        n_samples=30000,
        n_features=n_features,
        n_targets=n_targets,
    )
    ols_lm = LinearModel(LinearRegression())
    # The fitted estimator attribute must not exist before fitting ...
    assert not hasattr(ols_lm, "model_")
    ols_lm.fit(X, Y)
    # ... and afterwards it must be a different object from the input one.
    assert ols_lm.model is not ols_lm.model_
    assert_array_equal(ols_lm.filters_.shape, ols_lm.patterns_.shape)
    # A single target yields 1D filters, otherwise one row per target.
    expected_shape = (
        (n_features,) if n_targets == 1 else (n_targets, n_features)
    )
    assert_array_equal(ols_lm.filters_.shape, expected_shape)
    if multi_dim:
        assert_array_almost_equal(A, ols_lm.patterns_.T, decimal=2)

    # --- Ridge (alpha=0) wrapped in a one-step pipeline ---
    ridge_lm = LinearModel(Ridge(alpha=0))
    ridge_pipe = make_pipeline(ridge_lm)
    ridge_pipe.fit(X, Y)
    if multi_dim:
        assert_allclose(A, ridge_lm.patterns_.T, atol=2e-2)
    ridge_coef = get_coef(ridge_pipe, "patterns_", inverse_transform=True)
    assert_allclose(ridge_lm.patterns_, ridge_coef, atol=1e-5)

    # --- With epochs, scaler, and vectorizer (typical use case) ---
    X_epo = X.reshape((*X.shape, 1))  # append a trailing singleton axis
    info = create_info(n_features, 1000.0, "eeg")
    epochs_lm = LinearModel(Ridge(alpha=1))
    epochs_pipe = make_pipeline(
        Scaler(info, scalings={"eeg": 1.0}),  # XXX adding this step breaks
        Vectorizer(),
        epochs_lm,
    )
    epochs_pipe.fit(X_epo, Y)
    if multi_dim:
        assert_allclose(A, epochs_lm.patterns_.T, atol=2e-2)
    epochs_coef = get_coef(epochs_pipe, "patterns_", inverse_transform=True)

    # Expected shape is (n_targets, n_features), which is equivalent to
    # (n_components, n_channels) in spatial filters. 1D patterns gain a
    # leading axis; anything else gains a trailing one before comparison.
    reference = epochs_lm.patterns_
    reference = (
        reference[np.newaxis, :]
        if reference.ndim == 1
        else reference[..., np.newaxis]
    )
    assert_allclose(reference, epochs_coef, atol=1e-5)

    # Fitting parameters must be accepted by LinearModel.fit.
    epochs_lm.fit(X, Y, sample_weight=np.ones(len(Y)))
| 426 | |
| 427 | |
| 428 | @pytest.mark.parametrize( |
Nothing calls this directly.
No test coverage detected.