MCPcopy
hub / github.com/mne-tools/mne-python / test_get_coef

Function test_get_coef

mne/decoding/tests/test_base.py:99–230  ·  view source on GitHub ↗

Test getting linear coefficients (filters/patterns) from estimators.

()

Source from the content-addressed store, hash-verified

97
98
99def test_get_coef():
100 """Test getting linear coefficients (filters/patterns) from estimators."""
101 lm_classification = LinearModel(LogisticRegression(solver="liblinear"))
102 assert hasattr(lm_classification, "__sklearn_tags__")
103 if check_version("sklearn", "1.6"):
104 print(lm_classification.__sklearn_tags__())
105 assert is_classifier(lm_classification.model)
106 assert is_classifier(lm_classification)
107 assert not is_regressor(lm_classification.model)
108 assert not is_regressor(lm_classification)
109
110 lm_regression = LinearModel(Ridge())
111 assert is_regressor(lm_regression.model)
112 assert is_regressor(lm_regression)
113 assert not is_classifier(lm_regression.model)
114 assert not is_classifier(lm_regression)
115
116 parameters = {"kernel": ["linear"], "C": [1, 10]}
117 lm_gs_classification = LinearModel(
118 GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None)
119 )
120 assert is_classifier(lm_gs_classification)
121
122 lm_gs_regression = LinearModel(
123 GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None)
124 )
125 assert is_regressor(lm_gs_regression)
126
127 # Define a classifier, an invertible transformer and a non-invertible one.
128 assert BaseEstimator is sklearn_BaseEstimator
129 assert TransformerMixin is sklearn_TransformerMixin
130
131 class Clf(BaseEstimator):
132 def fit(self, X, y):
133 return self
134
135 class NoInv(TransformerMixin):
136 def fit(self, X, y):
137 return self
138
139 def transform(self, X):
140 return X
141
142 class Inv(NoInv):
143 def inverse_transform(self, X):
144 return X
145
146 X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1)
147
148 # I. Test inverse function
149
150 # Check that we retrieve the right number of inverse functions even if
151 # there are nested pipelines
152 good_estimators = [
153 (1, make_pipeline(Inv(), Clf())),
154 (2, make_pipeline(Inv(), Inv(), Clf())),
155 (3, make_pipeline(Inv(), make_pipeline(Inv(), Inv()), Clf())),
156 ]

Callers

nothing calls this directly

Calls 10

__sklearn_tags__ — Method · 0.95
LinearModel — Class · 0.90
check_version — Function · 0.90
_get_inverse_funcs — Function · 0.90
get_coef — Function · 0.90
Inv — Class · 0.85
Clf — Class · 0.85
NoInv — Class · 0.85
_make_data — Function · 0.70
fit — Method · 0.45

Tested by

no test coverage detected