Test XdawnTransformer.
()
| 237 | |
| 238 | |
def test_XdawnTransformer():
    """Test XdawnTransformer.

    Exercises ``XdawnTransformer.fit`` with and without labels, checks
    that invalid inputs are rejected with ``ValueError``, checks the
    accepted and rejected forms of ``signal_cov``, and compares the fitted
    spatial filters against those of ``Xdawn(correct_overlap=False)``.
    """
    pytest.importorskip("sklearn")
    # Get data
    raw, events, picks = _get_data()
    raw.del_proj()
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=picks,
        preload=True,
        baseline=None,
        verbose=False,
    )
    X = epochs._data
    y = epochs.events[:, -1]
    # Derive the channel count from the data instead of hard-coding it, so
    # the filter comparison below follows any change to ``picks``.
    n_channels = X.shape[1]

    # Fit
    xdt = XdawnTransformer()
    xdt.fit(X, y)
    # Labels whose length does not match the number of epochs are rejected.
    with pytest.raises(ValueError):
        xdt.fit(X, y[1:])
    # Non-array input is rejected.
    with pytest.raises(ValueError):
        xdt.fit("foo")

    # Provide covariance object
    signal_cov = compute_raw_covariance(raw, picks=picks)
    xdt = XdawnTransformer(signal_cov=signal_cov)
    xdt.fit(X, y)
    # Provide ndarray
    signal_cov = np.eye(len(picks))
    xdt = XdawnTransformer(signal_cov=signal_cov)
    xdt.fit(X, y)
    # Provide ndarray of bad shape (one channel too few)
    signal_cov = np.eye(len(picks) - 1)
    xdt = XdawnTransformer(signal_cov=signal_cov)
    with pytest.raises(ValueError):
        xdt.fit(X, y)
    # Provide another type
    signal_cov = 42
    xdt = XdawnTransformer(signal_cov=signal_cov)
    with pytest.raises(ValueError):
        xdt.fit(X, y)

    # Fit with y as None
    xdt = XdawnTransformer()
    xdt.fit(X)

    # Compare Xdawn and XdawnTransformer
    xd = Xdawn(correct_overlap=False)
    xd.fit(epochs)

    xdt = XdawnTransformer()
    xdt.fit(X, y)
    # Subset filters
    xdt_filters = xdt._subset_multi_components()
    # NOTE(review): the reshape presumably yields
    # (n_classes, n_components, n_channels) with 2 classes and 2 components,
    # and class 0 is assumed to correspond to "cond2" -- confirm against
    # ``_subset_multi_components`` and ``event_id``.
    assert_array_almost_equal(
        xd.filters_["cond2"][:2, :],
        xdt_filters.reshape(2, 2, n_channels)[0],
    )
| 296 |
nothing calls this directly (it is a pytest test function, collected and run automatically by pytest)
no test coverage detected