MCPcopy
hub / github.com/mne-tools/mne-python / test_search_light_basic

Function test_search_light_basic

mne/decoding/tests/test_search_light.py:44–192  ·  view source on GitHub ↗

Test SlidingEstimator.

()

Source from the content-addressed store, hash-verified

42
43
44def test_search_light_basic():
45 """Test SlidingEstimator."""
46 # https://github.com/scikit-learn/scikit-learn/issues/27711
47 if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"):
48 pytest.skip("sklearn int_t / long long mismatch")
49
50 logreg = OneVsRestClassifier(LogisticRegression(solver="liblinear", random_state=0))
51
52 X, y = make_data()
53 n_epochs, _, n_time = X.shape
54 # init
55 sl = SlidingEstimator("foo")
56 with pytest.raises(ValueError, match="must be"):
57 sl.fit(X, y)
58 sl = SlidingEstimator(Ridge())
59 assert not is_classifier(sl)
60 sl = SlidingEstimator(LogisticRegression(solver="liblinear"))
61 assert is_classifier(sl.base_estimator)
62 assert is_classifier(sl)
63 # fit
64 assert_equal(sl.__repr__()[:18], "<SlidingEstimator(")
65 sl.fit(X, y)
66 assert_equal(sl.__repr__()[-28:], ", fitted with 10 estimators>")
67 pytest.raises(ValueError, sl.fit, X[1:], y)
68 pytest.raises(ValueError, sl.fit, X[:, :, 0], y)
69 sl.fit(X, y, sample_weight=np.ones_like(y))
70
71 # transforms
72 pytest.raises(ValueError, sl.predict, X[:, :, :2])
73 y_trans = sl.transform(X)
74 assert X.dtype == float
75 assert y_trans.dtype == float
76 y_pred = sl.predict(X)
77 assert y_pred.dtype == np.dtype(int)
78 assert_array_equal(y_pred.shape, [n_epochs, n_time])
79 y_proba = sl.predict_proba(X)
80 assert y_proba.dtype == np.dtype(float)
81 assert_array_equal(y_proba.shape, [n_epochs, n_time, 2])
82
83 # score
84 score = sl.score(X, y)
85 assert_array_equal(score.shape, [n_time])
86 assert np.sum(np.abs(score)) != 0
87 assert score.dtype == np.dtype(float)
88
89 sl = SlidingEstimator(logreg)
90 assert_equal(sl.scoring, None)
91
92 # Scoring method
93 for scoring in ["foo", 999]:
94 sl = SlidingEstimator(logreg, scoring=scoring)
95 sl.fit(X, y)
96 pytest.raises((ValueError, TypeError), sl.score, X, y)
97
98 # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874
99 # -- 3 class problem
100 sl = SlidingEstimator(logreg, scoring="roc_auc")
101 y = np.arange(len(X)) % 3

Callers

nothing calls this directly

Calls 15

fit — method · 0.95
__repr__ — method · 0.95
transform — method · 0.95
predict — method · 0.95
predict_proba — method · 0.95
score — method · 0.95
decision_function — method · 0.95
check_version — function · 0.90
SlidingEstimator — class · 0.90
roc_auc_score — function · 0.90
Vectorizer — class · 0.90
make_data — function · 0.85

Tested by

no test coverage detected