MCPcopy Index your code
hub / github.com/mne-tools/mne-python / test_ems

Function test_ems

mne/decoding/tests/test_ems.py:26–94  ·  view source on GitHub ↗

Test event-matched spatial filters.

()

Source from the content-addressed store, hash-verified

24
25
26def test_ems():
27 """Test event-matched spatial filters."""
28 raw = io.read_raw_fif(raw_fname, preload=False)
29
30 # create unequal number of events
31 events = read_events(event_name)
32 events[-2, 2] = 3
33 picks = pick_types(
34 raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads"
35 )
36 picks = picks[1:13:3]
37 epochs = Epochs(
38 raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True
39 )
40 pytest.raises(ValueError, compute_ems, epochs, ["aud_l", "vis_l"])
41 epochs.equalize_event_counts(epochs.event_id)
42
43 pytest.raises(KeyError, compute_ems, epochs, ["blah", "hahah"])
44 surrogates, filters, conditions = compute_ems(epochs)
45 assert_equal(list(set(conditions)), [1, 3])
46
47 events = read_events(event_name)
48 event_id2 = dict(aud_l=1, aud_r=2, vis_l=3)
49 epochs = Epochs(
50 raw,
51 events,
52 event_id2,
53 tmin,
54 tmax,
55 picks=picks,
56 baseline=(None, 0),
57 preload=True,
58 )
59 epochs.equalize_event_counts(epochs.event_id)
60
61 n_expected = sum([len(epochs[k]) for k in ["aud_l", "vis_l"]])
62
63 pytest.raises(ValueError, compute_ems, epochs)
64 surrogates, filters, conditions = compute_ems(epochs, ["aud_r", "vis_l"])
65 assert_equal(n_expected, len(surrogates))
66 assert_equal(n_expected, len(conditions))
67 assert_equal(list(set(conditions)), [2, 3])
68
69 # test compute_ems cv
70 epochs = epochs["aud_r", "vis_l"]
71 epochs.equalize_event_counts(epochs.event_id)
72 cv = StratifiedKFold(n_splits=3)
73 compute_ems(epochs, cv=cv)
74 compute_ems(epochs, cv=2)
75 pytest.raises(ValueError, compute_ems, epochs, cv="foo")
76 pytest.raises(ValueError, compute_ems, epochs, cv=len(epochs) + 1)
77 raw.close()
78
79 # EMS transformer, check that identical to compute_ems
80 X = epochs.get_data(copy=False)
81 y = epochs.events[:, 2]
82 X = X / np.std(X) # X scaled outside cv in compute_ems
83 Xt, coefs = list(), list()

Callers

nothing calls this directly

Calls 14

__repr__ · Method · 0.95
fit · Method · 0.95
transform · Method · 0.95
read_events · Function · 0.90
pick_types · Function · 0.90
Epochs · Class · 0.90
compute_ems · Function · 0.90
EMS · Class · 0.90
set · Function · 0.85
equalize_event_counts · Method · 0.80
close · Method · 0.45
get_data · Method · 0.45

Tested by

no test coverage detected