Repository: github.com/mne-tools/mne-python

Method predict

mne/decoding/receptive_field.py:329–366

Generate predictions with a receptive field.
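The final reshaping steps of `predict` decide which optional axes appear in the return shape `(n_times[, n_epochs][, n_outputs])`. The function below is a hypothetical, array-free transcription of that bookkeeping, written only to make the cases concrete. It makes three assumptions. `x_dim` is the original `X.ndim`. `y_dim` is the dimensionality of `y` recorded at fit time (`self._y_dim`). `n_outputs` stands in for `coef_.shape[0]`. It also assumes 2D input has been given a singleton epochs axis, so `n_epochs` is 1 in that case.

```python
def predicted_shape(n_times, n_epochs, n_outputs, x_dim, y_dim):
    """Hypothetical helper mirroring the shape logic at the end of predict()."""
    # Internally X is 3D (n_times, n_epochs, n_features), so the raw
    # prediction starts as (n_times, n_epochs[, n_outputs]).
    shape = [n_times, n_epochs]
    if y_dim > 1:
        shape.append(n_outputs)
    if x_dim <= 2:
        shape.pop(1)  # 2D input: drop the singleton epochs axis
        extra = 0
    else:
        extra = 1
    # Slice the axis list exactly as predict() does: shape[: y_dim + extra]
    return tuple(shape[: y_dim + extra])


print(predicted_shape(1000, 1, 1, x_dim=2, y_dim=1))  # (1000,)
print(predicted_shape(1000, 1, 3, x_dim=2, y_dim=2))  # (1000, 3)
print(predicted_shape(200, 10, 3, x_dim=3, y_dim=3))  # (200, 10, 3)
```

In short, with 2D `X` the epochs axis disappears. The outputs axis appears only when `y` was fit with more than one dimension.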

(self, X)


```python
# Excerpt of ReceptiveField.predict, dedented from the class body.
# NotFittedError and _reshape_view are module-level names in that file.
def predict(self, X):
    """Generate predictions with a receptive field.

    Parameters
    ----------
    X : array, shape (n_times[, n_epochs], n_channels)
        The input features for the model.

    Returns
    -------
    y_pred : array, shape (n_times[, n_epochs][, n_outputs])
        The output predictions. Note that valid samples (those
        unaffected by edge artifacts during the time delaying step) can
        be obtained using ``y_pred[rf.valid_samples_]``.
    """
    # Refuse to predict before fit() has set the delays
    if not hasattr(self, "delays_"):
        raise NotFittedError("Estimator has not been fit yet.")

    # Validate X; _check_dimensions also brings 2D input to 3D,
    # (n_times, n_epochs, n_features)
    X, _ = self._check_data(X)
    X, _, X_dim = self._check_dimensions(X, None, predict=True)[:3]

    del _
    # convert to sklearn and back
    pred_shape = X.shape[:-1]
    if self._y_dim > 1:
        pred_shape = pred_shape + (self.coef_.shape[0],)
    X, _ = self._delay_and_reshape(X)
    y_pred = self.estimator_.predict(X)
    # Undo the flattening of (n_times, n_epochs); order="F" must match how
    # the samples were stacked into rows
    y_pred = y_pred.reshape(pred_shape, order="F")
    shape = list(y_pred.shape)
    if X_dim <= 2:
        shape.pop(1)  # epochs
        extra = 0
    else:
        extra = 1
    shape = shape[: self._y_dim + extra]
    y_pred = _reshape_view(y_pred, shape)
    return y_pred
```
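The `# convert to sklearn and back` step flattens `(n_times, n_epochs)` into the sample axis of a 2D design matrix. It later restores that axis with `reshape(..., order="F")`. Below is a minimal NumPy sketch of the round trip, under three assumptions. `delay_features` is a hypothetical stand-in for MNE's private delaying code; it uses zero fill and the sign convention stated in its docstring. A random weight matrix `W` stands in for a fitted `estimator_`. The forward flattening uses Fortran order, which is what the `order="F"` reshape in `predict` requires.

```python
import numpy as np


def delay_features(X, delays):
    """Stack time-shifted copies of X (hypothetical helper, zero-filled).

    X: (n_times, n_epochs, n_features) -> (n_times, n_epochs, n_features, n_delays).
    Assumed convention: delay d >= 0 puts X[t - d] at row t (first d rows
    stay zero); d < 0 leaves the last |d| rows zero.
    """
    n_times = X.shape[0]
    out = np.zeros(X.shape + (len(delays),), dtype=X.dtype)
    for i, d in enumerate(delays):
        if d >= 0:
            out[d:, ..., i] = X[: n_times - d]
        else:
            out[:d, ..., i] = X[-d:]
    return out


rng = np.random.default_rng(0)
n_times, n_epochs, n_features, n_outputs = 50, 4, 3, 2
delays = np.arange(-2, 3)  # lags in samples
X = rng.standard_normal((n_times, n_epochs, n_features))
# Stand-in for a fitted linear estimator: one weight per (feature, delay) column
W = rng.standard_normal((n_features * len(delays), n_outputs))

# "Convert to sklearn": merge features and delays into columns, then stack
# (time, epoch) pairs into rows in Fortran order, so row r holds
# time r % n_times of epoch r // n_times.
X_del = delay_features(X, delays)
X_2d = X_del.reshape(n_times, n_epochs, -1).reshape(n_times * n_epochs, -1, order="F")
y_flat = X_2d @ W  # plays the role of estimator_.predict: (n_times * n_epochs, n_outputs)

# "And back": the same Fortran order restores (n_times, n_epochs, n_outputs)
y_pred = y_flat.reshape((n_times, n_epochs, n_outputs), order="F")

# Check against predicting each epoch separately
per_epoch = np.stack(
    [X_del[:, e].reshape(n_times, -1) @ W for e in range(n_epochs)], axis=1
)
print(np.allclose(y_pred, per_epoch))  # True
# A C-order reshape would scramble times and epochs
print(np.allclose(y_flat.reshape((n_times, n_epochs, n_outputs)), per_epoch))  # False
```

Fortran order keeps each epoch's time points contiguous in the flattened rows. One 2D `predict` call therefore covers all epochs, and a single reshape undoes it.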

Callers (5)

score · method · 0.95
test_rank_deficiency · function · 0.95
test_receptive_field_1d · function · 0.95
test_receptive_field_nd · function · 0.95

Calls (4)

_check_data · method · 0.95
_check_dimensions · method · 0.95
_delay_and_reshape · method · 0.95
_reshape_view · function · 0.85
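The docstring points to `y_pred[rf.valid_samples_]` for samples unaffected by edge artifacts. Time delaying (the `_delay_and_reshape` step listed under Calls) shifts each feature along the time axis. Rows near the start and end of each epoch therefore mix in fill values instead of real data. The sketch below shows which rows survive. It assumes zero fill and the sign convention stated in the helper's docstring. `valid_slice` is a hypothetical helper and is not claimed to match how MNE computes `valid_samples_`.

```python
import numpy as np


def valid_slice(delays):
    """Rows of a delayed design matrix that contain no padding (hypothetical helper).

    Assumes a delay d >= 0 pads the first d rows and d < 0 pads the last |d| rows.
    """
    start = max(int(np.max(delays)), 0)
    stop = min(int(np.min(delays)), 0)
    return slice(start if start > 0 else None, stop if stop < 0 else None)


# Demonstration on a constant signal, with zeros as the padding value
n_times = 10
x = np.ones(n_times)
delays = np.arange(-2, 3)
delayed = np.zeros((n_times, len(delays)))
for i, d in enumerate(delays):
    if d >= 0:
        delayed[d:, i] = x[: n_times - d]
    else:
        delayed[:d, i] = x[-d:]

sl = valid_slice(delays)
print(sl)  # slice(2, -2, None)
print(delayed[sl].all())  # True: no padded rows remain
print(delayed.all())  # False: the edge rows contain padding
```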

Tested by (4)

test_rank_deficiency · function · 0.76
test_receptive_field_1d · function · 0.76
test_receptive_field_nd · function · 0.76