MCPcopy
hub / github.com/mne-tools/mne-python / test_linearmodel

Function test_linearmodel

mne/decoding/tests/test_base.py:473–536  ·  view source on GitHub ↗

Test LinearModel class for computing filters and patterns.

()

Source from the content-addressed store, hash-verified

471
472
473def test_linearmodel():
474 """Test LinearModel class for computing filters and patterns."""
475 # check categorical target fit in standard linear model
476 rng = np.random.RandomState(0)
477 clf = LinearModel()
478 n, n_features = 20, 3
479 X = rng.rand(n, n_features)
480 y = np.arange(n) % 2
481 clf.fit(X, y)
482 assert_equal(clf.filters_.shape, (n_features,))
483 assert_equal(clf.patterns_.shape, (n_features,))
484 with pytest.raises(ValueError):
485 wrong_X = rng.rand(n, n_features, 99)
486 clf.fit(wrong_X, y)
487
488 # check fit_transform call
489 clf = LinearModel(LinearDiscriminantAnalysis())
490 _ = clf.fit_transform(X, y)
491
492 # check that model has to have coef_, RBF-SVM doesn't
493 clf = LinearModel(svm.SVC(kernel="rbf"))
494 with pytest.raises(ValueError, match="does not have a `coef_`"):
495 clf.fit(X, y)
496
497 # check that model has to be a predictor
498 clf = LinearModel(StandardScaler())
499 with pytest.raises(ValueError, match="classifier or regressor"):
500 clf.fit(X, y)
501
502 # check categorical target fit in standard linear model with GridSearchCV
503 parameters = {"kernel": ["linear"], "C": [1, 10]}
504 clf = LinearModel(
505 GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None)
506 )
507 clf.fit(X, y)
508 assert_equal(clf.filters_.shape, (n_features,))
509 assert_equal(clf.patterns_.shape, (n_features,))
510 with pytest.raises(ValueError):
511 wrong_X = rng.rand(n, n_features, 99)
512 clf.fit(wrong_X, y)
513
514 # check continuous target fit in standard linear model with GridSearchCV
515 n_targets = 1
516 Y = rng.rand(n, n_targets)
517 clf = LinearModel(
518 GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None)
519 )
520 clf.fit(X, y)
521 assert_equal(clf.filters_.shape, (n_features,))
522 assert_equal(clf.patterns_.shape, (n_features,))
523 with pytest.raises(ValueError):
524 wrong_y = rng.rand(n, n_features, 99)
525 clf.fit(X, wrong_y)
526
527 # check multi-target fit in standard linear model
528 n_targets = 5
529 Y = rng.rand(n, n_targets)
530 clf = LinearModel(LinearRegression())

Callers

nothing calls this directly

Calls 3

fit (method) · 0.95
LinearModel (class) · 0.90
fit_transform (method) · 0.45

Tested by

no test coverage detected