Test the Common Spatial Patterns (CSP) algorithm on epochs.
`test_csp()`
@pytest.mark.slowtest
def test_csp():
    """Test Common Spatial Patterns algorithm on epochs.

    Covers three stages:

    * parameter validation: invalid ``n_components``, ``reg``, ``cov_est``
      and ``norm_trace`` values must raise from ``fit``, and valid ones must
      fit cleanly;
    * fitting with and without trace normalization (``mean_``/``std_``
      must have one entry per component in both modes);
    * transforming: ``transform`` must agree with ``fit_transform``, and the
      filter/pattern matrices must be square in the channel dimension.
    """
    raw = io.read_raw_fif(raw_fname, preload=False)
    events = read_events(event_name)
    picks = pick_types(
        raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads"
    )
    # Subselect channels -> disable proj! The existing projectors are
    # removed here and ``proj=False`` is passed to Epochs below.
    picks = picks[2:12:3]
    raw.add_proj([], remove_existing=True)
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=picks,
        baseline=(None, 0),
        preload=True,
        proj=False,
    )
    epochs_data = epochs.get_data(copy=False)
    n_channels = epochs_data.shape[1]
    y = epochs.events[:, -1]

    # Init: bad parameters are reported by ``fit``, not by the constructor,
    # so every validation check goes through ``fit``.
    csp = CSP(n_components="foo")
    with pytest.raises(TypeError, match="must be an instance"):
        csp.fit(epochs_data, y)
    for reg in ["foo", -0.1, 1.1]:
        csp = CSP(reg=reg, norm_trace=False)
        with pytest.raises(ValueError):
            csp.fit(epochs_data, y)
    # Valid regularization settings must be accepted by ``fit``; merely
    # constructing the estimator would not exercise any validation.
    for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]:
        CSP(reg=reg, norm_trace=False).fit(epochs_data, y)
    csp = CSP(cov_est="foo", norm_trace=False)
    with pytest.raises(ValueError, match="Invalid value"):
        csp.fit(epochs_data, y)
    csp = CSP(norm_trace="foo")
    with pytest.raises(TypeError, match="instance of bool"):
        csp.fit(epochs_data, y)
    for cov_est in ["concat", "epoch"]:
        CSP(cov_est=cov_est, norm_trace=False).fit(epochs_data, y)

    n_components = 3
    # Fit: check the fitted attributes for *each* trace-normalization mode,
    # not just for the last estimator created by the loop.
    for norm_trace in [True, False]:
        csp = CSP(n_components=n_components, norm_trace=norm_trace)
        csp.fit(epochs_data, y)
        assert len(csp.mean_) == n_components
        assert len(csp.std_) == n_components

    # Transform (reuses the last estimator from the loop: norm_trace=False)
    X = csp.fit_transform(epochs_data, y)
    sources = csp.transform(epochs_data)
    assert sources.shape[1] == n_components
    assert csp.filters_.shape == (n_channels, n_channels)
    assert csp.patterns_.shape == (n_channels, n_channels)
    assert_array_almost_equal(sources, X)
Nothing calls this function directly; as a `test_*` function it is collected and run by pytest.
No test coverage detected.