Test event-matched spatial filters.
`test_ems()`
| 24 | |
| 25 | |
| 26 | def test_ems(): |
| 27 | """Test event-matched spatial filters.""" |
| 28 | raw = io.read_raw_fif(raw_fname, preload=False) |
| 29 | |
| 30 | # create unequal number of events |
| 31 | events = read_events(event_name) |
| 32 | events[-2, 2] = 3 |
| 33 | picks = pick_types( |
| 34 | raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" |
| 35 | ) |
| 36 | picks = picks[1:13:3] |
| 37 | epochs = Epochs( |
| 38 | raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True |
| 39 | ) |
| 40 | pytest.raises(ValueError, compute_ems, epochs, ["aud_l", "vis_l"]) |
| 41 | epochs.equalize_event_counts(epochs.event_id) |
| 42 | |
| 43 | pytest.raises(KeyError, compute_ems, epochs, ["blah", "hahah"]) |
| 44 | surrogates, filters, conditions = compute_ems(epochs) |
| 45 | assert_equal(list(set(conditions)), [1, 3]) |
| 46 | |
| 47 | events = read_events(event_name) |
| 48 | event_id2 = dict(aud_l=1, aud_r=2, vis_l=3) |
| 49 | epochs = Epochs( |
| 50 | raw, |
| 51 | events, |
| 52 | event_id2, |
| 53 | tmin, |
| 54 | tmax, |
| 55 | picks=picks, |
| 56 | baseline=(None, 0), |
| 57 | preload=True, |
| 58 | ) |
| 59 | epochs.equalize_event_counts(epochs.event_id) |
| 60 | |
| 61 | n_expected = sum([len(epochs[k]) for k in ["aud_l", "vis_l"]]) |
| 62 | |
| 63 | pytest.raises(ValueError, compute_ems, epochs) |
| 64 | surrogates, filters, conditions = compute_ems(epochs, ["aud_r", "vis_l"]) |
| 65 | assert_equal(n_expected, len(surrogates)) |
| 66 | assert_equal(n_expected, len(conditions)) |
| 67 | assert_equal(list(set(conditions)), [2, 3]) |
| 68 | |
| 69 | # test compute_ems cv |
| 70 | epochs = epochs["aud_r", "vis_l"] |
| 71 | epochs.equalize_event_counts(epochs.event_id) |
| 72 | cv = StratifiedKFold(n_splits=3) |
| 73 | compute_ems(epochs, cv=cv) |
| 74 | compute_ems(epochs, cv=2) |
| 75 | pytest.raises(ValueError, compute_ems, epochs, cv="foo") |
| 76 | pytest.raises(ValueError, compute_ems, epochs, cv=len(epochs) + 1) |
| 77 | raw.close() |
| 78 | |
| 79 | # EMS transformer, check that identical to compute_ems |
| 80 | X = epochs.get_data(copy=False) |
| 81 | y = epochs.events[:, 2] |
| 82 | X = X / np.std(X) # X scaled outside cv in compute_ems |
| 83 | Xt, coefs = list(), list() |
Nothing calls this function directly; as a `test_*` function it is collected and run by pytest.
No test coverage detected.