Test SlidingEstimator.
`test_search_light_basic()`
| 42 | |
| 43 | |
| 44 | def test_search_light_basic(): |
| 45 | """Test SlidingEstimator.""" |
| 46 | # https://github.com/scikit-learn/scikit-learn/issues/27711 |
| 47 | if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"): |
| 48 | pytest.skip("sklearn int_t / long long mismatch") |
| 49 | |
| 50 | logreg = OneVsRestClassifier(LogisticRegression(solver="liblinear", random_state=0)) |
| 51 | |
| 52 | X, y = make_data() |
| 53 | n_epochs, _, n_time = X.shape |
| 54 | # init |
| 55 | sl = SlidingEstimator("foo") |
| 56 | with pytest.raises(ValueError, match="must be"): |
| 57 | sl.fit(X, y) |
| 58 | sl = SlidingEstimator(Ridge()) |
| 59 | assert not is_classifier(sl) |
| 60 | sl = SlidingEstimator(LogisticRegression(solver="liblinear")) |
| 61 | assert is_classifier(sl.base_estimator) |
| 62 | assert is_classifier(sl) |
| 63 | # fit |
| 64 | assert_equal(sl.__repr__()[:18], "<SlidingEstimator(") |
| 65 | sl.fit(X, y) |
| 66 | assert_equal(sl.__repr__()[-28:], ", fitted with 10 estimators>") |
| 67 | pytest.raises(ValueError, sl.fit, X[1:], y) |
| 68 | pytest.raises(ValueError, sl.fit, X[:, :, 0], y) |
| 69 | sl.fit(X, y, sample_weight=np.ones_like(y)) |
| 70 | |
| 71 | # transforms |
| 72 | pytest.raises(ValueError, sl.predict, X[:, :, :2]) |
| 73 | y_trans = sl.transform(X) |
| 74 | assert X.dtype == float |
| 75 | assert y_trans.dtype == float |
| 76 | y_pred = sl.predict(X) |
| 77 | assert y_pred.dtype == np.dtype(int) |
| 78 | assert_array_equal(y_pred.shape, [n_epochs, n_time]) |
| 79 | y_proba = sl.predict_proba(X) |
| 80 | assert y_proba.dtype == np.dtype(float) |
| 81 | assert_array_equal(y_proba.shape, [n_epochs, n_time, 2]) |
| 82 | |
| 83 | # score |
| 84 | score = sl.score(X, y) |
| 85 | assert_array_equal(score.shape, [n_time]) |
| 86 | assert np.sum(np.abs(score)) != 0 |
| 87 | assert score.dtype == np.dtype(float) |
| 88 | |
| 89 | sl = SlidingEstimator(logreg) |
| 90 | assert_equal(sl.scoring, None) |
| 91 | |
| 92 | # Scoring method |
| 93 | for scoring in ["foo", 999]: |
| 94 | sl = SlidingEstimator(logreg, scoring=scoring) |
| 95 | sl.fit(X, y) |
| 96 | pytest.raises((ValueError, TypeError), sl.score, X, y) |
| 97 | |
| 98 | # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874 |
| 99 | # -- 3 class problem |
| 100 | sl = SlidingEstimator(logreg, scoring="roc_auc") |
| 101 | y = np.arange(len(X)) % 3 |
nothing calls this directly
no test coverage detected