MCPcopy
hub / github.com/mne-tools/mne-python / test_pos_semidef_inv

Function test_pos_semidef_inv

mne/utils/tests/test_linalg.py:34–103  ·  view source on GitHub ↗

Test positive semidefinite matrix inverses.

(ndim, dtype, n, deficient, reduce_rank, psdef, func)

Source from the content-addressed store, hash-verified

32 ],
33)
34def test_pos_semidef_inv(ndim, dtype, n, deficient, reduce_rank, psdef, func):
35 """Test positive semidefinite matrix inverses."""
36 svd = np.linalg.svd
37 # make n-dimensional matrix
38 n_extra = 2 # how many we add along the other dims
39 rng = np.random.RandomState(73)
40 shape = (n_extra,) * (ndim - 2) + (n, n)
41 mat = rng.randn(*shape) + 1j * rng.randn(*shape)
42 proj = np.eye(n)
43 if deficient:
44 vec = np.ones(n) / np.sqrt(n)
45 proj -= np.outer(vec, vec)
46 with _record_warnings(): # intentionally discard imag
47 mat = mat.astype(dtype)
48 # now make it conjugate symmetric or positive semi-definite
49 if psdef:
50 mat = np.matmul(mat, mat.swapaxes(-2, -1).conj())
51 else:
52 mat += mat.swapaxes(-2, -1).conj()
53 assert_allclose(mat, mat.swapaxes(-2, -1).conj(), atol=1e-6)
54 s = svd(mat, hermitian=True)[1]
55 assert (s >= 0).all()
56 # make it rank deficient (maybe)
57 if deficient:
58 mat = np.matmul(np.matmul(proj, mat), proj)
59 # if the dtype is complex, the conjugate transpose != transpose
60 kwargs = dict(atol=1e-10, rtol=1e-10)
61 orig_eq_t = np.allclose(mat, mat.swapaxes(-2, -1), **kwargs)
62 t_eq_ct = np.allclose(mat.swapaxes(-2, -1), mat.conj().swapaxes(-2, -1), **kwargs)
63 if np.iscomplexobj(mat):
64 assert not orig_eq_t
65 assert not t_eq_ct
66 else:
67 assert t_eq_ct
68 assert orig_eq_t
69 assert mat.shape == shape
70 # ensure pos-semidef
71 s = np.linalg.svd(mat, compute_uv=False)
72 assert s.shape == shape[:-1]
73 rank = (s > s[..., :1] * 1e-12).sum(-1)
74 want_rank = n - deficient
75 assert_array_equal(rank, want_rank)
76 # assert equiv with NumPy
77 mat_pinv = np.linalg.pinv(mat)
78 if func is _sym_mat_pow:
79 if not psdef:
80 with pytest.raises(ValueError, match="not positive semi-"):
81 func(mat, -1)
82 return
83 mat_symv = func(mat, -1, reduce_rank=reduce_rank)
84 mat_sqrt = func(mat, 0.5)
85 if ndim == 2:
86 mat_sqrt_scipy = linalg.sqrtm(mat)
87 assert_allclose(mat_sqrt, mat_sqrt_scipy, atol=1e-6)
88 mat_2 = np.matmul(mat_sqrt, mat_sqrt)
89 assert_allclose(mat, mat_2, atol=1e-6)
90 mat_symv_2 = func(mat, -0.5, reduce_rank=reduce_rank)
91 mat_symv_2 = np.matmul(mat_symv_2, mat_symv_2)

Callers

nothing calls this directly

Calls 5

_record_warnings — Function · 0.90
sqrt — Method · 0.80
func — Function · 0.50
sum — Method · 0.45
mean — Method · 0.45

Tested by

no test coverage detected