Repository: github.com/mne-tools/mne-python

Function test_XdawnTransformer

mne/preprocessing/tests/test_xdawn.py:239–311

Test XdawnTransformer.

Signature: test_XdawnTransformer() (takes no arguments)


```python
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal

from mne import Epochs, compute_raw_covariance

# Xdawn, XdawnTransformer, _get_data, event_id, tmin and tmax are imported
# or defined at module level in mne/preprocessing/tests/test_xdawn.py.


def test_XdawnTransformer():
    """Test XdawnTransformer."""
    pytest.importorskip("sklearn")
    # Get data
    raw, events, picks = _get_data()
    raw.del_proj()
    epochs = Epochs(
        raw,
        events,
        event_id,
        tmin,
        tmax,
        picks=picks,
        preload=True,
        baseline=None,
        verbose=False,
    )
    X = epochs._data
    y = epochs.events[:, -1]
    # Fit
    xdt = XdawnTransformer()
    xdt.fit(X, y)
    # Invalid inputs must raise: y length mismatch, non-array X
    pytest.raises(ValueError, xdt.fit, X, y[1:])
    pytest.raises(ValueError, xdt.fit, "foo")

    # Provide a Covariance object
    signal_cov = compute_raw_covariance(raw, picks=picks)
    xdt = XdawnTransformer(signal_cov=signal_cov)
    xdt.fit(X, y)
    # Provide an ndarray
    signal_cov = np.eye(len(picks))
    xdt = XdawnTransformer(signal_cov=signal_cov)
    xdt.fit(X, y)
    # Provide an ndarray of the wrong shape
    signal_cov = np.eye(len(picks) - 1)
    xdt = XdawnTransformer(signal_cov=signal_cov)
    pytest.raises(ValueError, xdt.fit, X, y)
    # Provide an unsupported type
    signal_cov = 42
    xdt = XdawnTransformer(signal_cov=signal_cov)
    pytest.raises(ValueError, xdt.fit, X, y)

    # Fit with y as None
    xdt = XdawnTransformer()
    xdt.fit(X)

    # Compare Xdawn and XdawnTransformer
    xd = Xdawn(correct_overlap=False)
    xd.fit(epochs)

    xdt = XdawnTransformer()
    xdt.fit(X, y)
    # Subset filters, reshape them into (2, 2, 8) blocks and compare the first
    # block with the first two "cond2" filters found by Xdawn
    xdt_filters = xdt._subset_multi_components()
    assert_array_almost_equal(
        xd.filters_["cond2"][:2, :], xdt_filters.reshape(2, 2, 8)[0]
    )

    # ... (the rest of the function, through line 311, is not shown here)
```

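The final assertion compares the first two rows of `xd.filters_["cond2"]` with the first block of the subset filters reshaped to `(2, 2, 8)`, which reads as per-class blocks of 2 components over 8 channels stacked along the first axis. The sketch below shows the idea behind such filters with a simplified Xdawn without overlap correction. For each class it averages the epochs (the evoked response), then solves the generalized eigenproblem between the evoked covariance and the signal covariance. It keeps the top `n_components` eigenvectors as row filters and stacks the classes in sorted label order. It is a minimal NumPy sketch under assumptions, not MNE's code: the name `xdawn_filters`, the unregularized covariances, and solving the generalized problem through Cholesky whitening are illustrative choices.

```python
import numpy as np


def xdawn_filters(X, y, n_components=2):
    """Simplified Xdawn without overlap correction (hypothetical helper).

    X: (n_epochs, n_channels, n_times); y: (n_epochs,) class labels.
    Returns filters stacked as (n_classes * n_components, n_channels).
    """
    n_channels = X.shape[1]
    # Signal covariance over all samples of all epochs (unregularized)
    signal_cov = np.cov(X.transpose(1, 0, 2).reshape(n_channels, -1))
    # Whitening via Cholesky: signal_cov = L @ L.T
    L_inv = np.linalg.inv(np.linalg.cholesky(signal_cov))
    filters = []
    for cls in np.unique(y):  # labels in sorted order
        evoked = X[y == cls].mean(axis=0)  # (n_channels, n_times)
        evoked_cov = np.cov(evoked)
        # evoked_cov w = lam * signal_cov w becomes the symmetric problem
        # M v = lam v with M = L^-1 evoked_cov L^-T and w = L^-T v
        M = L_inv @ evoked_cov @ L_inv.T
        evals, evecs = np.linalg.eigh(M)  # ascending eigenvalues
        top = np.argsort(evals)[::-1][:n_components]
        W = L_inv.T @ evecs[:, top]  # (n_channels, n_components)
        filters.append(W.T)  # one row per component
    return np.concatenate(filters, axis=0)


# Usage on random data: 2 classes (labels 2 and 3), 8 channels
rng = np.random.default_rng(0)
X = rng.standard_normal((20, 8, 100))
y = np.repeat([2, 3], 10)
W = xdawn_filters(X, y)  # shape (4, 8)
first_class = W.reshape(2, 2, 8)[0]  # the 2 filters for label 2
```

With this stacking, `W.reshape(2, 2, 8)[0]` is simply `W[:2]`, the same slicing pattern the test applies to `xdt._subset_multi_components()`. Each filter `w` is scaled so that `w @ signal_cov @ w == 1`, the usual normalization for generalized symmetric eigenvectors.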
Callers

nothing calls this directly

Calls (11)

fit (method) · 0.95
fit (method) · 0.95
transform (method) · 0.95
inverse_transform (method) · 0.95
Epochs (class) · 0.90
XdawnTransformer (class) · 0.90
compute_raw_covariance (function) · 0.90
Xdawn (class) · 0.90
del_proj (method) · 0.80
_get_data (function) · 0.70
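The call list also includes `transform` and `inverse_transform`, which the excerpt of the function shown above does not reach. For a linear spatial-filtering transformer, these usually amount to projecting each epoch onto the filters and back-projecting the components to sensor space. The sketch below assumes that model, with patterns taken as the pseudo-inverse of the filters. The function names are hypothetical, and this is not a description of MNE's implementation.

```python
import numpy as np


def spatial_transform(X, filters):
    """Project epochs onto spatial filters (hypothetical helper).

    X: (n_epochs, n_channels, n_times); filters: (n_filters, n_channels).
    Returns components of shape (n_epochs, n_filters, n_times).
    """
    # Apply the same channel-mixing matrix to every epoch
    return np.einsum("fc,ect->eft", filters, X)


def spatial_inverse_transform(S, filters):
    """Back-project components to sensors, assuming patterns = pinv(filters)."""
    patterns = np.linalg.pinv(filters)  # (n_channels, n_filters)
    return np.einsum("cf,eft->ect", patterns, S)
```

With a full, invertible filter matrix the round trip reconstructs the data exactly. With fewer filters than channels, `inverse_transform` returns only the part of the data spanned by the kept patterns. Transforming that reconstruction again recovers the same components.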

Tested by

no test coverage detected