MCPcopy
hub / github.com/mne-tools/mne-python / test_get_coef_multiclass

Function test_get_coef_multiclass

mne/decoding/tests/test_base.py:376–425  ·  view source on GitHub ↗

Test get_coef on multiclass problems.

Parameters (pytest parametrization): (n_features, n_targets)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("n_features", [1, 5])
@pytest.mark.parametrize("n_targets", [1, 3])
def test_get_coef_multiclass(n_features, n_targets):
    """Test get_coef on multiclass problems.

    For every ``(n_features, n_targets)`` combination this checks that:

    - ``LinearModel`` exposes ``filters_`` and ``patterns_`` of equal shape,
      1-D ``(n_features,)`` for a single target and
      ``(n_targets, n_features)`` otherwise;
    - in the fully multivariate case the estimated patterns match the
      ``A`` returned by ``_make_data`` (presumably the ground-truth
      forward/mixing matrix -- see that helper);
    - ``get_coef(..., inverse_transform=True)`` returns the model's
      patterns when the model sits in a pipeline, both alone and behind
      ``Scaler`` + ``Vectorizer`` on epochs-shaped input;
    - extra fit parameters such as ``sample_weight`` can be passed to
      ``LinearModel.fit``.
    """
    # Pattern recovery is only compared against A when there is more than
    # one feature AND more than one target.
    multivariate = n_features > 1 and n_targets > 1

    # Check patterns with more than 1 regressor
    X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets)
    lm = LinearModel(LinearRegression())
    assert not hasattr(lm, "model_")
    lm.fit(X, Y)
    # The fitted estimator (model_) must be a separate object from the
    # unfitted template passed in (model).
    assert lm.model is not lm.model_
    assert_array_equal(lm.filters_.shape, lm.patterns_.shape)
    if n_targets == 1:
        want_shape = (n_features,)
    else:
        want_shape = (n_targets, n_features)
    assert_array_equal(lm.filters_.shape, want_shape)
    if multivariate:
        assert_array_almost_equal(A, lm.patterns_.T, decimal=2)

    # Same model inside a (trivial) pipeline: get_coef must find the
    # LinearModel step and return its patterns unchanged.
    lm = LinearModel(Ridge(alpha=0))
    clf = make_pipeline(lm)
    clf.fit(X, Y)
    if multivariate:
        assert_allclose(A, lm.patterns_.T, atol=2e-2)
    coef = get_coef(clf, "patterns_", inverse_transform=True)
    assert_allclose(lm.patterns_, coef, atol=1e-5)

    # With epochs, scaler, and vectorizer (typical use case).
    # A trailing singleton axis turns each sample into a one-time-point
    # epoch, with one EEG channel per feature.
    X_epo = X.reshape(X.shape + (1,))
    info = create_info(n_features, 1000.0, "eeg")
    lm = LinearModel(Ridge(alpha=1))
    # NOTE(review): the Scaler line used to carry "XXX adding this step
    # breaks", but the step is in the pipeline and the assertions below
    # run with it -- the note appears stale; confirm before relying on it.
    clf = make_pipeline(
        Scaler(info, scalings=dict(eeg=1.0)),
        Vectorizer(),
        lm,
    )
    clf.fit(X_epo, Y)
    if multivariate:
        assert_allclose(A, lm.patterns_.T, atol=2e-2)
    coef = get_coef(clf, "patterns_", inverse_transform=True)

    # Expected shape is (n_targets, n_features), which is equivalent to
    # (n_components, n_channels) in spatial filters. Reshape lm.patterns_
    # to the shape get_coef returns after inverse-transforming through the
    # pipeline: a single target gains a leading axis -> (1, n_features);
    # multiple targets gain the trailing time axis -> (n_targets, n_features, 1).
    lm_patterns_ = lm.patterns_
    if lm_patterns_.ndim == 1:
        lm_patterns_ = lm_patterns_[np.newaxis, :]
    else:
        lm_patterns_ = lm_patterns_[..., np.newaxis]
    assert_allclose(lm_patterns_, coef, atol=1e-5)

    # Check can pass fitting parameters (forwarded through LinearModel.fit)
    lm.fit(X, Y, sample_weight=np.ones(len(Y)))
    # Refitting on the 2-D data must still yield filters of the same shape
    # as the first fit above.
    assert_array_equal(lm.filters_.shape, want_shape)
426
427
428@pytest.mark.parametrize(

Callers

nothing calls this directly

Calls 8

fit (method) · 0.95
LinearModel (class) · 0.90
get_coef (function) · 0.90
create_info (function) · 0.90
Scaler (class) · 0.90
Vectorizer (class) · 0.90
_make_data (function) · 0.70
fit (method) · 0.45

Tested by

no test coverage detected