MCPcopy
hub / github.com/mne-tools/mne-python / test_get_coef_multiclass_full

Function test_get_coef_multiclass_full

mne/decoding/tests/test_base.py:437–470  ·  view source on GitHub ↗

Test a full example with pattern extraction.

(n_classes, n_channels, n_times)

Source from the content-addressed store, hash-verified

435 ],
436)
def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
    """Test a full multiclass decoding example with pattern extraction.

    Synthetic epochs are built so that only channel 0 carries class
    information: it holds the epoch's class index at every time sample.
    A time-generalizing one-vs-rest logistic-regression pipeline is
    cross-validated and must score above a threshold. The pipeline is then
    fitted once, and the recovered patterns must vanish on every other
    channel.
    """
    n_per_class = 10
    # Class index of each epoch: ten 0s, then ten 1s, ..., up to n_classes - 1
    labels = np.repeat(np.arange(n_classes), n_per_class)
    data = np.zeros((n_per_class * n_classes, n_channels, n_times))
    # Only the first channel is informative; all other channels stay zero
    data[:, 0] = labels[:, np.newaxis]
    events = np.zeros((len(data), 3), int)
    events[:, 0] = np.arange(len(events))  # one distinct sample per event
    events[:, 2] = labels  # event id doubles as the class label
    info = create_info(n_channels, 1000.0, "eeg")
    epochs = EpochsArray(data, info, events, tmin=0)

    pipeline = make_pipeline(
        Scaler(epochs.info),
        Vectorizer(),
        LinearModel(OneVsRestClassifier(LogisticRegression(random_state=0))),
    )
    generalizer = GeneralizingEstimator(
        pipeline, "roc_auc_ovr_weighted", verbose=True
    )
    X = epochs.get_data(copy=False)
    y = epochs.events[:, 2]

    n_folds = 3
    fold_scores = cross_val_multiscore(
        generalizer, X, y, cv=StratifiedKFold(n_splits=n_folds), verbose=True
    )
    # A single time sample collapses the train-time x test-time matrix
    expected_shape = (n_folds, n_times, n_times) if n_times > 1 else (n_folds,)
    assert fold_scores.shape == expected_shape

    # LBFGS may not fully converge on Windows, so accept a lower score there
    min_score = 0.7 if platform.system() == "Windows" else 0.8
    assert_array_less(min_score, fold_scores)

    pipeline.fit(X, y)
    patterns = get_coef(pipeline, "patterns_", inverse_transform=True)
    assert patterns.shape == (n_classes, n_channels, n_times)
    # Every channel except the first carries no information
    assert_allclose(patterns[:, 1:], 0.0, atol=1e-7)
471
472
473def test_linearmodel():

Callers

nothing calls this directly

Calls 10

create_info — Function · 0.90
EpochsArray — Class · 0.90
Scaler — Class · 0.90
Vectorizer — Class · 0.90
LinearModel — Class · 0.90
cross_val_multiscore — Function · 0.90
get_coef — Function · 0.90
get_data — Method · 0.45
fit — Method · 0.45

Tested by

no test coverage detected