MCPcopy
hub / github.com/mne-tools/mne-python / test_regularized_csp

Function test_regularized_csp

mne/decoding/tests/test_csp.py:270–380  ·  view source on GitHub ↗

Test Common Spatial Patterns algorithm using regularized covariance.

Parameters: (ch_type, rank, reg)

Source from the content-addressed store, hash-verified

268@pytest.mark.parametrize("rank", (None, "full", "correct"))
269@pytest.mark.parametrize("reg", [None, 0.001, "oas"])
270def test_regularized_csp(ch_type, rank, reg):
271 """Test Common Spatial Patterns algorithm using regularized covariance."""
272 raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads").load_data()
273 n_orig = len(raw.ch_names)
274 ch_decim = 2
275 raw.pick_channels(raw.ch_names[::ch_decim])
276 raw.info.normalize_proj()
277 if "eeg" in ch_type:
278 raw.set_eeg_reference(projection=True)
279 # TODO: for some reason we need to add a second EEG projector in order to get
280 # the non-semidefinite error for EEG data. Hopefully this won't make much
281 # difference in practice given our default is rank=None and regularization
282 # is easy to use.
283 raw.add_proj(compute_proj_raw(raw, n_eeg=1, n_mag=0, n_grad=0, n_jobs=1))
284 n_eig = len(raw.ch_names) - len(raw.info["projs"])
285 n_ch = n_orig // ch_decim
286 if ch_type == "eeg":
287 assert n_eig == n_ch - 2
288 elif ch_type == "mag":
289 assert n_eig == n_ch - 3
290 else:
291 assert n_eig == n_ch - 5
292 if rank == "correct":
293 if isinstance(ch_type, str):
294 rank = {ch_type: n_eig}
295 else:
296 assert ch_type == ("mag", "eeg")
297 rank = dict(
298 mag=102 // ch_decim - 3,
299 eeg=60 // ch_decim - 2,
300 )
301 else:
302 assert rank is None or rank == "full", rank
303 if rank == "full":
304 n_eig = n_ch
305 raw.filter(2, 40).apply_proj()
306 events = read_events(event_name)
307 # map make left and right events the same
308 events[events[:, 2] == 2, 2] = 1
309 events[events[:, 2] == 4, 2] = 3
310 epochs = Epochs(raw, events, event_id, tmin, tmax, decim=5, preload=True)
311 epochs.equalize_event_counts()
312 assert 25 < len(epochs) < 30
313 epochs_data = epochs.get_data(copy=False)
314 n_channels = epochs_data.shape[1]
315 assert n_channels == n_ch
316 n_components = 3
317
318 sc = Scaler(epochs.info)
319 epochs_data_orig = epochs_data.copy()
320 epochs_data = sc.fit_transform(epochs_data)
321 csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank)
322 if rank == "full" and reg is None:
323 with pytest.raises(np.linalg.LinAlgError, match="leading minor"):
324 csp.fit(epochs_data, epochs.events[:, -1])
325 return
326 with catch_logging(verbose=True) as log:
327 X = csp.fit_transform(epochs_data, epochs.events[:, -1])

Callers

Nothing calls this directly; it is a parametrized test that pytest collects and runs.

Calls 15 (12 listed below)

fit_transform · Method · 0.95
fit · Method · 0.95
fit_transform · Method · 0.95
transform · Method · 0.95
inverse_transform · Method · 0.95
compute_proj_raw · Function · 0.90
read_events · Function · 0.90
Epochs · Class · 0.90
Scaler · Class · 0.90
CSP · Class · 0.90
catch_logging · Class · 0.90
LinearModel · Class · 0.90

Tested by

no test coverage detected (this function is itself a test in mne/decoding/tests/test_csp.py)