Repository: github.com/mne-tools/mne-python

Function test_xdawn_save_load

mne/decoding/tests/test_xdawn.py, lines 22–73

Test that XdawnTransformer can be saved to disk and loaded correctly.

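Signature: test_xdawn_save_load(tmp_path), where tmp_path is pytest's built-in fixture that provides a per-test temporary directory as a pathlib.Path.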
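The test's first checks concern what gets serialized: `__getstate__` must leave out `cov_callable` and `mod_ged_callable`, and the object returned by `read_xdawn_transformer` must have both back as callables, presumably because function objects cannot be stored in the saved file. The following is a minimal sketch of that pattern under stated assumptions: JSON stands in for HDF5, and `TinyTransformer`, `default_cov`, and `from_state` are hypothetical names, not MNE's implementation.

```python
import json


def default_cov(X):
    """Hypothetical covariance stand-in: per-row sum of squares."""
    return [sum(v * v for v in row) for row in X]


class TinyTransformer:
    """Hypothetical estimator that holds a callable next to plain parameters."""

    def __init__(self, n_components=2, reg=None):
        self.n_components = n_components
        self.reg = reg
        # A function-valued attribute: json.dumps cannot encode it.
        self.cov_callable = default_cov
        self.filters_ = None

    def fit(self, X):
        # Toy "fit": keep the first n_components rows as the filters.
        self.filters_ = [list(row) for row in X[: self.n_components]]
        return self

    def __getstate__(self):
        # Serializable snapshot: copy the attributes, drop the callable.
        state = self.__dict__.copy()
        state.pop("cov_callable", None)
        return state

    @classmethod
    def from_state(cls, state):
        # __init__ re-creates cov_callable with its default; the fitted
        # attribute is then copied back from the saved state.
        obj = cls(n_components=state["n_components"], reg=state["reg"])
        obj.filters_ = state["filters_"]
        return obj


est = TinyTransformer(n_components=2).fit([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
state = est.__getstate__()
text = json.dumps(state)  # would raise TypeError if cov_callable were still present
loaded = TinyTransformer.from_state(json.loads(text))
```

Rebuilding the callable on load keeps the file limited to plain data; the trade-off in this sketch is that a custom callable set on the original object is replaced by the default after loading.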


```python
# Module-level imports (not part of lines 22–73). The import location of
# read_xdawn_transformer is assumed; the excerpt does not show it.
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal

from mne.decoding import XdawnTransformer, read_xdawn_transformer


def test_xdawn_save_load(tmp_path):
    """Test that XdawnTransformer can be saved to disk and loaded correctly."""
    h5io = pytest.importorskip("h5io")
    rng = np.random.RandomState(42)
    n_epochs, n_channels, n_times = 40, 10, 50
    X = rng.randn(n_epochs, n_channels, n_times)
    y = rng.randint(0, 2, n_epochs)

    xdawn = XdawnTransformer(n_components=2)
    xdawn.fit(X, y)

    # Callable attributes must be excluded from the serializable state
    state = xdawn.__getstate__()
    assert "cov_callable" not in state
    assert "mod_ged_callable" not in state

    fname = tmp_path / "test_xdawn.h5"
    xdawn.save(fname)

    xdawn_loaded = read_xdawn_transformer(fname)

    # The callables must be present again after loading
    assert hasattr(xdawn_loaded, "cov_callable")
    assert hasattr(xdawn_loaded, "mod_ged_callable")
    assert callable(xdawn_loaded.cov_callable)
    assert callable(xdawn_loaded.mod_ged_callable)

    # Check fitted array attributes are restored
    assert_array_almost_equal(xdawn.filters_, xdawn_loaded.filters_)
    assert_array_almost_equal(xdawn.patterns_, xdawn_loaded.patterns_)

    # Check scalar/param attributes
    assert xdawn.n_components == xdawn_loaded.n_components
    assert xdawn.reg == xdawn_loaded.reg
    assert xdawn.rank == xdawn_loaded.rank
    assert xdawn.restr_type == xdawn_loaded.restr_type

    # Check transform output matches
    X_orig = xdawn.transform(X)
    X_loaded = xdawn_loaded.transform(X)
    assert_array_almost_equal(X_orig, X_loaded)

    # Saving over an existing file fails unless overwrite=True
    with pytest.raises(FileExistsError):
        xdawn.save(fname)
    xdawn.save(fname, overwrite=True)

    # Check that loading an HDF5 file with missing keys raises an error
    bad_fname = tmp_path / "bad_xdawn.h5"
    h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace")
    with pytest.raises(ValueError, match="missing required keys"):
        read_xdawn_transformer(bad_fname)

    # Loading a nonexistent file raises an OSError
    with pytest.raises(OSError, match="not found|does not exist"):
        read_xdawn_transformer(tmp_path / "nonexistent.h5")
```
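The last three checks exercise the file-handling contract: a second `save` to the same path raises `FileExistsError` unless `overwrite=True`, a file lacking the expected keys raises a `ValueError` mentioning "missing required keys", and a nonexistent path raises an `OSError`. Below is a minimal sketch of a saver and loader with that behavior, assuming JSON instead of HDF5 and a hypothetical `REQUIRED_KEYS` schema; `save_state`, `read_state`, and `error_name` are illustrative names, not MNE functions.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical schema: keys every saved state file must contain.
REQUIRED_KEYS = ("n_components", "filters_")


def save_state(state, fname, overwrite=False):
    """Write a state dict as JSON; refuse to replace a file unless overwrite=True."""
    fname = Path(fname)
    if fname.exists() and not overwrite:
        raise FileExistsError(f"{fname} already exists; use overwrite=True")
    fname.write_text(json.dumps(state))


def read_state(fname):
    """Load a state dict, checking that the file exists and has the required keys."""
    fname = Path(fname)
    if not fname.exists():
        # FileNotFoundError is a subclass of OSError.
        raise FileNotFoundError(f"File {fname} not found")
    state = json.loads(fname.read_text())
    missing = [key for key in REQUIRED_KEYS if key not in state]
    if missing:
        raise ValueError(f"{fname} is missing required keys: {missing}")
    return state


def error_name(func, *args, **kwargs):
    """Demo helper: return the name of the exception func raises, or None."""
    try:
        func(*args, **kwargs)
    except Exception as err:
        return type(err).__name__
    return None


with tempfile.TemporaryDirectory() as tmp_dir:
    tmp = Path(tmp_dir)
    good = tmp / "xdawn.json"
    state = {"n_components": 2, "filters_": [[1.0, 0.0], [0.0, 1.0]]}
    save_state(state, good)
    roundtrip = read_state(good)
    second_save = error_name(save_state, state, good)  # "FileExistsError"
    save_state(state, good, overwrite=True)  # explicit replacement is allowed
    bad = tmp / "bad.json"
    bad.write_text(json.dumps({"foo": "bar"}))
    bad_read = error_name(read_state, bad)  # "ValueError"
    missing_read = error_name(read_state, tmp / "nonexistent.json")  # "FileNotFoundError"
```

Raising `FileNotFoundError` satisfies a check that expects `OSError`, because `FileNotFoundError` is a subclass of `OSError`; the message "File ... not found" also matches the pattern `not found|does not exist`.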

Callers

nothing calls this directly

Calls (7)

fit · method · 0.95
transform · method · 0.95
XdawnTransformer · class · 0.90
read_xdawn_transformer · function · 0.90
__getstate__ · method · 0.45
save · method · 0.45
transform · method · 0.45

Tested by

no test coverage detected