Test Common Spatial Patterns algorithm using regularized covariance.
(ch_type, rank, reg)
| 268 | @pytest.mark.parametrize("rank", (None, "full", "correct")) |
| 269 | @pytest.mark.parametrize("reg", [None, 0.001, "oas"]) |
| 270 | def test_regularized_csp(ch_type, rank, reg): |
| 271 | """Test Common Spatial Patterns algorithm using regularized covariance.""" |
| 272 | raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads").load_data() |
| 273 | n_orig = len(raw.ch_names) |
| 274 | ch_decim = 2 |
| 275 | raw.pick_channels(raw.ch_names[::ch_decim]) |
| 276 | raw.info.normalize_proj() |
| 277 | if "eeg" in ch_type: |
| 278 | raw.set_eeg_reference(projection=True) |
| 279 | # TODO: for some reason we need to add a second EEG projector in order to get |
| 280 | # the non-semidefinite error for EEG data. Hopefully this won't make much |
| 281 | # difference in practice given our default is rank=None and regularization |
| 282 | # is easy to use. |
| 283 | raw.add_proj(compute_proj_raw(raw, n_eeg=1, n_mag=0, n_grad=0, n_jobs=1)) |
| 284 | n_eig = len(raw.ch_names) - len(raw.info["projs"]) |
| 285 | n_ch = n_orig // ch_decim |
| 286 | if ch_type == "eeg": |
| 287 | assert n_eig == n_ch - 2 |
| 288 | elif ch_type == "mag": |
| 289 | assert n_eig == n_ch - 3 |
| 290 | else: |
| 291 | assert n_eig == n_ch - 5 |
| 292 | if rank == "correct": |
| 293 | if isinstance(ch_type, str): |
| 294 | rank = {ch_type: n_eig} |
| 295 | else: |
| 296 | assert ch_type == ("mag", "eeg") |
| 297 | rank = dict( |
| 298 | mag=102 // ch_decim - 3, |
| 299 | eeg=60 // ch_decim - 2, |
| 300 | ) |
| 301 | else: |
| 302 | assert rank is None or rank == "full", rank |
| 303 | if rank == "full": |
| 304 | n_eig = n_ch |
| 305 | raw.filter(2, 40).apply_proj() |
| 306 | events = read_events(event_name) |
| 307 | # make left and right events the same |
| 308 | events[events[:, 2] == 2, 2] = 1 |
| 309 | events[events[:, 2] == 4, 2] = 3 |
| 310 | epochs = Epochs(raw, events, event_id, tmin, tmax, decim=5, preload=True) |
| 311 | epochs.equalize_event_counts() |
| 312 | assert 25 < len(epochs) < 30 |
| 313 | epochs_data = epochs.get_data(copy=False) |
| 314 | n_channels = epochs_data.shape[1] |
| 315 | assert n_channels == n_ch |
| 316 | n_components = 3 |
| 317 | |
| 318 | sc = Scaler(epochs.info) |
| 319 | epochs_data_orig = epochs_data.copy() |
| 320 | epochs_data = sc.fit_transform(epochs_data) |
| 321 | csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank) |
| 322 | if rank == "full" and reg is None: |
| 323 | with pytest.raises(np.linalg.LinAlgError, match="leading minor"): |
| 324 | csp.fit(epochs_data, epochs.events[:, -1]) |
| 325 | return |
| 326 | with catch_logging(verbose=True) as log: |
| 327 | X = csp.fit_transform(epochs_data, epochs.events[:, -1]) |
nothing calls this directly
no test coverage detected