hub / github.com/mne-tools/mne-python / _bcd

Function _bcd

mne/inverse_sparse/mxne_optim.py:286–339

Implement one full pass of BCD (Block Coordinate Descent). This function makes use of scipy.linalg.get_blas_funcs for speed reasons.

(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc, list_G_j_c)
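Read off the source, the update applied to each block j can be written compactly. This is a sketch, not an official statement of the algorithm: c_j stands for one_ovr_lc[j], t_j for alpha_lc[j], and the norm is the Frobenius norm, assuming sum_squared returns the sum of squared entries. It also assumes the first-entry check (X_j[0, 0] != 0) agrees with a check of the whole block.

```latex
\begin{aligned}
\tilde{X}_j &= X_j + c_j \, G_j^{\top} R, \\
X_j^{\mathrm{new}} &=
\begin{cases}
0 & \text{if } \lVert \tilde{X}_j \rVert_F \le t_j, \\
\left(1 - \dfrac{t_j}{\lVert \tilde{X}_j \rVert_F}\right) \tilde{X}_j & \text{otherwise},
\end{cases} \\
R &\leftarrow R + G_j \left( X_j - X_j^{\mathrm{new}} \right).
\end{aligned}
```

The last line is what keeps R equal to M - G @ X after every block update.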

def _bcd(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc, list_G_j_c):
    """Implement one full pass of BCD.

    BCD stands for Block Coordinate Descent.
    This function makes use of scipy.linalg.get_blas_funcs for speed reasons.

    Parameters
    ----------
    G : array, shape (n_sensors, n_active)
        The gain matrix a.k.a. lead field.
    X : array, shape (n_sources, n_times)
        Sources, modified in place.
    R : array, shape (n_sensors, n_times)
        The residuals: R = M - G @ X, modified in place.
    active_set : array of bool, shape (n_sources, )
        Mask of active sources, modified in place.
    one_ovr_lc : array, shape (n_positions, )
        One over the Lipschitz constants.
    n_orient : int
        Number of dipoles per position (typically 1 or 3).
    alpha_lc : array, shape (n_positions, )
        alpha * (Lipschitz constants).
    list_G_j_c : list of array, each of shape (n_sensors, n_orient)
        One gain block G[:, j * n_orient:(j + 1) * n_orient] per position,
        used in the BLAS residual updates.
    """
    X_j_new = np.zeros_like(X[:n_orient, :], order="C")
    dgemm = _get_dgemm()

    for j, G_j_c in enumerate(list_G_j_c):
        idx = slice(j * n_orient, (j + 1) * n_orient)
        G_j = G[:, idx]
        X_j = X[idx]
        dgemm(
            alpha=one_ovr_lc[j], beta=0.0, a=R.T, b=G_j, c=X_j_new.T, overwrite_c=True
        )
        # X_j_new = one_ovr_lc[j] * G_j.T @ R
        # Mathurin's trick to avoid checking all the entries
        was_non_zero = X_j[0, 0] != 0
        # was_non_zero = np.any(X_j)
        if was_non_zero:
            dgemm(alpha=1.0, beta=1.0, a=X_j.T, b=G_j_c.T, c=R.T, overwrite_c=True)
            # R += np.dot(G_j, X_j)
            X_j_new += X_j
        block_norm = sqrt(sum_squared(X_j_new))
        if block_norm <= alpha_lc[j]:
            X_j.fill(0.0)
            active_set[idx] = False
        else:
            shrink = max(1.0 - alpha_lc[j] / block_norm, 0.0)
            X_j_new *= shrink
            dgemm(alpha=-1.0, beta=1.0, a=X_j_new.T, b=G_j_c.T, c=R.T, overwrite_c=True)
            # R -= np.dot(G_j, X_j_new)
            X_j[:] = X_j_new
            active_set[idx] = True
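For readers who want to follow the logic without the BLAS calls, the same pass can be sketched in plain NumPy. The function name bcd_pass_numpy is hypothetical and not part of MNE. The sketch drops the list_G_j_c argument and tests the whole block with np.any instead of only its first entry, so it is a readable stand-in rather than a drop-in replacement.

```python
import numpy as np


def bcd_pass_numpy(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc):
    """One BCD pass in plain NumPy (hypothetical helper, not part of MNE)."""
    n_positions = G.shape[1] // n_orient
    for j in range(n_positions):
        idx = slice(j * n_orient, (j + 1) * n_orient)
        G_j = G[:, idx]
        X_j = X[idx]  # a view: writes below modify X in place
        # Scaled correlation of block j with the current residual.
        X_j_new = one_ovr_lc[j] * (G_j.T @ R)
        # Checks the whole block, unlike the first-entry shortcut in _bcd.
        if np.any(X_j):
            R += G_j @ X_j  # add block j's contribution back to the residual
            X_j_new += X_j
        block_norm = np.linalg.norm(X_j_new)  # Frobenius norm of the block
        if block_norm <= alpha_lc[j]:
            X_j[:] = 0.0
            active_set[idx] = False
        else:
            # Block soft-thresholding: shrink the whole block toward zero.
            X_j_new *= 1.0 - alpha_lc[j] / block_norm
            R -= G_j @ X_j_new  # subtract the updated contribution
            X_j[:] = X_j_new
            active_set[idx] = True
```

Called repeatedly with R initialised to M - G @ X, the function keeps that identity true after every pass. With a threshold of alpha / L_j and a step of 1 / L_j, where L_j is the squared spectral norm of G_j (the usual proximal-gradient choice, assumed here), the objective does not increase from one pass to the next.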

Callers 1

_mixed_norm_solver_bcd · Function · 0.85

Calls 3

_get_dgemm · Function · 0.85
sum_squared · Function · 0.85
fill · Method · 0.80
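The body of _get_dgemm is not shown on this page. As an illustration of the pattern the dgemm calls rely on, the sketch below obtains a float64 GEMM routine through scipy.linalg.get_blas_funcs and writes the result into a C-contiguous array through its transposed view. How MNE actually builds its handle is an assumption here; the array names mirror those in _bcd, and scale stands in for one_ovr_lc[j].

```python
import numpy as np
from scipy.linalg import get_blas_funcs

rng = np.random.default_rng(0)
n_sensors, n_orient, n_times = 4, 3, 5
R = rng.standard_normal((n_sensors, n_times))  # C-contiguous residual
G_j = np.asfortranarray(rng.standard_normal((n_sensors, n_orient)))
X_j_new = np.zeros((n_orient, n_times))  # C-contiguous output buffer
scale = 0.5  # stands in for one_ovr_lc[j]

# Assumption: a float64 GEMM handle, similar in spirit to _get_dgemm().
dgemm = get_blas_funcs("gemm", dtype=np.float64)

# The transpose of a C-contiguous array is a Fortran-contiguous view, so
# R.T needs no copy and X_j_new.T can be overwritten in place by BLAS.
out = dgemm(alpha=scale, a=R.T, b=G_j, beta=0.0, c=X_j_new.T, overwrite_c=True)

# BLAS computed X_j_new.T = scale * R.T @ G_j, i.e. X_j_new = scale * G_j.T @ R.
expected = scale * (G_j.T @ R)
```

Passing transposed views is what lets a Fortran-ordered BLAS routine update row-major NumPy buffers such as X_j_new and R without an intermediate copy.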

Tested by

no test coverage detected