hub / github.com/scikit-learn/scikit-learn / fit

Method fit

sklearn/kernel_approximation.py:1011–1074

Fit estimator to data. Samples a subset of training points, computes the kernel on these, and computes the normalization matrix.

(self, X, y=None)
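A minimal usage sketch, assuming this `fit` is the one defined on `sklearn.kernel_approximation.Nystroem` (the file path and the `test_nystroem_callable` test name suggest so). The data, `gamma`, and `n_components` values are arbitrary illustrations. It shows that `fit` can be called without `y` and returns the estimator itself.

```python
import numpy as np
from sklearn.kernel_approximation import Nystroem

rng = np.random.RandomState(0)
X = rng.rand(20, 5)  # 20 samples, 5 features (made-up data)

nys = Nystroem(kernel="rbf", gamma=0.5, n_components=8, random_state=0)
fitted = nys.fit(X)  # y defaults to None for this unsupervised transformer

returns_self = fitted is nys               # fit returns the instance itself
basis_shape = nys.components_.shape        # sampled basis rows of X
norm_shape = nys.normalization_.shape      # square matrix over the basis
print(returns_self, basis_shape, norm_shape)
```

With `n_components=8` and 20 samples, the basis holds 8 rows of `X` and the normalization matrix is 8 by 8.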

Source from the content-addressed store, hash-verified

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
    """Fit estimator to data.

    Samples a subset of training points, computes kernel
    on these and computes normalization matrix.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Training data, where `n_samples` is the number of samples
        and `n_features` is the number of features.

    y : array-like, shape (n_samples,) or (n_samples, n_outputs), \
        default=None
        Target values (None for unsupervised transformations).

    Returns
    -------
    self : object
        Returns the instance itself.
    """
    xp, _, device = get_namespace_and_device(X)
    X = validate_data(self, X, accept_sparse="csr")
    rnd = check_random_state(self.random_state)
    n_samples = X.shape[0]

    # get basis vectors
    if self.n_components > n_samples:
        # XXX should we just bail?
        n_components = n_samples
        warnings.warn(
            "n_components > n_samples. This is not possible.\n"
            "n_components was set to n_samples, which results"
            " in inefficient evaluation of the full kernel."
        )

    else:
        n_components = self.n_components
    n_components = min(n_samples, n_components)
    inds = rnd.permutation(n_samples)
    basis_inds = xp.asarray(inds[:n_components], dtype=xp.int64, device=device)
    if sp.issparse(X):
        basis = X[basis_inds]
    else:
        basis = _safe_indexing(X, basis_inds, axis=0)

    basis_kernel = pairwise_kernels(
        basis,
        metric=self.kernel,
        filter_params=True,
        n_jobs=self.n_jobs,
        **self._get_kernel_params(),
    )

    # sqrt of kernel matrix on basis vectors
    _, _, dtype = _find_floating_dtype_allow_sparse(basis_kernel, Y=None, xp=xp)
    basis_kernel = xp.asarray(basis_kernel, dtype=dtype, device=device)
    U, S, V = xp.linalg.svd(basis_kernel)
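The steps in the listing (clamp `n_components`, permute the sample indices, take the first `n_components` as the basis, compute the kernel on the basis, take its SVD) can be sketched in plain NumPy. This is an illustration under stated assumptions, not the library's code: `rbf` and `fit_sketch` are hypothetical helpers, the RBF kernel stands in for `pairwise_kernels`, and the last two lines of `fit_sketch` are an assumption, since the listing stops at the SVD. They show one common way to turn the SVD into an inverse square root of the kernel matrix.

```python
import warnings
import numpy as np


def rbf(A, B, gamma=1.0):
    """Plain NumPy RBF kernel (hypothetical stand-in for pairwise_kernels)."""
    sq = (A**2).sum(axis=1)[:, None] + (B**2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-gamma * np.maximum(sq, 0.0))


def fit_sketch(X, n_components, seed=0):
    """Hypothetical re-statement of the steps shown in the listing."""
    rng = np.random.RandomState(seed)
    n_samples = X.shape[0]

    # Clamp n_components to n_samples and warn, as the listing does.
    if n_components > n_samples:
        warnings.warn("n_components > n_samples; using n_samples instead.")
        n_components = n_samples

    # A random permutation gives n_components distinct row indices.
    inds = rng.permutation(n_samples)
    basis_inds = inds[:n_components]
    basis = X[basis_inds]

    # Kernel matrix on the basis rows, then its SVD.
    K = rbf(basis, basis)
    U, S, Vt = np.linalg.svd(K)

    # Assumption (not in the listing): floor tiny singular values and
    # form U diag(S^-1/2) Vt, an inverse square root of K.
    S = np.maximum(S, 1e-12)
    normalization = (U / np.sqrt(S)) @ Vt
    return basis_inds, basis, K, normalization


X = np.random.RandomState(1).rand(10, 3)  # made-up data
basis_inds, basis, K, N = fit_sketch(X, n_components=4)
```

Under that assumption, `N @ K @ N.T` is close to the identity, which is what a normalization matrix for the basis kernel would be expected to achieve.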

Calls 8

_get_kernel_params · Method · 0.95
get_namespace_and_device · Function · 0.90
validate_data · Function · 0.90
check_random_state · Function · 0.90
_safe_indexing · Function · 0.90
pairwise_kernels · Function · 0.90
min · Function · 0.85
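Of the calls listed, `pairwise_kernels` is invoked in the source with `filter_params=True`. A small sketch of what that flag does, assuming the public `sklearn.metrics.pairwise` API: keyword arguments that the chosen kernel does not accept are dropped rather than passed through. The `params` dictionary is a made-up example, not the actual return value of `_get_kernel_params()`.

```python
import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels, rbf_kernel

X = np.random.RandomState(0).rand(6, 3)  # made-up data

# Hypothetical parameter set: only gamma is meaningful for the RBF kernel.
params = {"gamma": 0.5, "degree": 3, "coef0": 1}

# With filter_params=True, degree and coef0 are filtered out for metric="rbf".
K = pairwise_kernels(X, metric="rbf", filter_params=True, **params)

# Reference: the RBF kernel called directly with gamma only.
K_ref = rbf_kernel(X, gamma=0.5)
same = np.allclose(K, K_ref)
print(K.shape, same)
```

Filtering this way lets one estimator carry parameters for several kernels and forward the whole set, whichever kernel is selected.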

Tested by 4

test_nystroem_callable · Function · 0.76