Test that XdawnTransformer can be saved to disk and loaded correctly.
(tmp_path)
| 20 | |
| 21 | |
def test_xdawn_save_load(tmp_path):
    """Test that XdawnTransformer can be saved to disk and loaded correctly.

    Exercises the round trip through ``XdawnTransformer.save`` and
    ``read_xdawn_transformer``:

    * ``__getstate__`` must not contain ``cov_callable`` / ``mod_ged_callable``,
      while the loaded object must have both restored as callables.
    * Fitted arrays (``filters_``, ``patterns_``) and constructor parameters
      (``n_components``, ``reg``, ``rank``, ``restr_type``) must match.
    * ``transform`` must give the same output before and after the round trip.
    * ``save`` must raise ``FileExistsError`` unless ``overwrite=True``, and
      the overwritten file must itself load correctly.
    * Reading an HDF5 file lacking the expected keys raises ``ValueError``;
      reading a nonexistent path raises ``OSError``.
    """
    h5io = pytest.importorskip("h5io")
    rng = np.random.RandomState(42)
    n_epochs, n_channels, n_times = 40, 10, 50
    X = rng.randn(n_epochs, n_channels, n_times)
    y = rng.randint(0, 2, n_epochs)

    xdawn = XdawnTransformer(n_components=2)
    xdawn.fit(X, y)

    # The callable attributes must be excluded from the saved state
    # (presumably because functions cannot be written to HDF5 — confirm
    # against XdawnTransformer.__getstate__).
    state = xdawn.__getstate__()
    assert "cov_callable" not in state
    assert "mod_ged_callable" not in state

    fname = tmp_path / "test_xdawn.h5"
    xdawn.save(fname)

    xdawn_loaded = read_xdawn_transformer(fname)
    assert isinstance(xdawn_loaded, XdawnTransformer)

    # Loading must rebuild the callables that were dropped from the state.
    assert hasattr(xdawn_loaded, "cov_callable")
    assert hasattr(xdawn_loaded, "mod_ged_callable")
    assert callable(xdawn_loaded.cov_callable)
    assert callable(xdawn_loaded.mod_ged_callable)

    # Check fitted array attributes are restored
    assert_array_almost_equal(xdawn.filters_, xdawn_loaded.filters_)
    assert_array_almost_equal(xdawn.patterns_, xdawn_loaded.patterns_)

    # Check scalar/param attributes
    assert xdawn.n_components == xdawn_loaded.n_components
    assert xdawn.reg == xdawn_loaded.reg
    assert xdawn.rank == xdawn_loaded.rank
    assert xdawn.restr_type == xdawn_loaded.restr_type

    # Check transform output matches
    X_orig = xdawn.transform(X)
    X_loaded = xdawn_loaded.transform(X)
    assert_array_almost_equal(X_orig, X_loaded)

    # Saving to an existing path must be refused unless overwrite=True ...
    with pytest.raises(FileExistsError):
        xdawn.save(fname)
    xdawn.save(fname, overwrite=True)

    # ... and the overwritten file must still round-trip; otherwise a save
    # that silently wrote nothing or corrupted data would go unnoticed.
    xdawn_reloaded = read_xdawn_transformer(fname)
    assert_array_almost_equal(xdawn.filters_, xdawn_reloaded.filters_)
    assert_array_almost_equal(X_orig, xdawn_reloaded.transform(X))

    # Check that loading an HDF5 file with missing keys raises an error
    bad_fname = tmp_path / "bad_xdawn.h5"
    h5io.write_hdf5(bad_fname, dict(foo="bar"), title="mnepython", slash="replace")
    with pytest.raises(ValueError, match="missing required keys"):
        read_xdawn_transformer(bad_fname)

    # A path that does not exist must raise an OSError subclass
    # (e.g. FileNotFoundError).
    with pytest.raises(OSError, match="not found|does not exist"):
        read_xdawn_transformer(tmp_path / "nonexistent.h5")
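Nothing calls this directly, and no test coverage was detected for it. Both are expected, because this is a pytest test function that pytest collects by its `test_` prefix.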