MCPcopy
hub / github.com/mne-tools/mne-python / test_csp

Function test_csp

mne/decoding/tests/test_csp.py:118–262  ·  view source on GitHub ↗

Test Common Spatial Patterns algorithm on epochs.

Signature: `test_csp()` (no parameters)

Source retrieved from the content-addressed store (hash-verified).

116
117@pytest.mark.slowtest
118def test_csp():
119 """Test Common Spatial Patterns algorithm on epochs."""
120 raw = io.read_raw_fif(raw_fname, preload=False)
121 events = read_events(event_name)
122 picks = pick_types(
123 raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads"
124 )
125 picks = picks[2:12:3] # subselect channels -> disable proj!
126 raw.add_proj([], remove_existing=True)
127 epochs = Epochs(
128 raw,
129 events,
130 event_id,
131 tmin,
132 tmax,
133 picks=picks,
134 baseline=(None, 0),
135 preload=True,
136 proj=False,
137 )
138 epochs_data = epochs.get_data(copy=False)
139 n_channels = epochs_data.shape[1]
140 y = epochs.events[:, -1]
141
142 # Init
143 csp = CSP(n_components="foo")
144 with pytest.raises(TypeError, match="must be an instance"):
145 csp.fit(epochs_data, y)
146 for reg in ["foo", -0.1, 1.1]:
147 csp = CSP(reg=reg, norm_trace=False)
148 pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1])
149 for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]:
150 CSP(reg=reg, norm_trace=False)
151 csp = CSP(cov_est="foo", norm_trace=False)
152 with pytest.raises(ValueError, match="Invalid value"):
153 csp.fit(epochs_data, y)
154 csp = CSP(norm_trace="foo")
155 with pytest.raises(TypeError, match="instance of bool"):
156 csp.fit(epochs_data, y)
157 for cov_est in ["concat", "epoch"]:
158 CSP(cov_est=cov_est, norm_trace=False).fit(epochs_data, y)
159
160 n_components = 3
161 # Fit
162 for norm_trace in [True, False]:
163 csp = CSP(n_components=n_components, norm_trace=norm_trace)
164 csp.fit(epochs_data, epochs.events[:, -1])
165
166 assert_equal(len(csp.mean_), n_components)
167 assert_equal(len(csp.std_), n_components)
168
169 # Transform
170 X = csp.fit_transform(epochs_data, y)
171 sources = csp.transform(epochs_data)
172 assert sources.shape[1] == n_components
173 assert csp.filters_.shape == (n_channels, n_channels)
174 assert csp.patterns_.shape == (n_channels, n_channels)
175 assert_array_almost_equal(sources, X)

Callers

nothing calls this directly

Calls 14

- `fit` — method · 0.95
- `fit_transform` — method · 0.95
- `transform` — method · 0.95
- `read_events` — function · 0.90
- `pick_types` — function · 0.90
- `Epochs` — class · 0.90
- `CSP` — class · 0.90
- `add_proj` — method · 0.80
- `simulate_data` — function · 0.70
- `get_data` — method · 0.45
- `fit` — method · 0.45
- `pick` — method · 0.45

Tested by

no test coverage detected