Method predict

sklearn/multiclass.py:482–520 (github.com/scikit-learn/scikit-learn)

Predict multi-class targets using underlying estimators.

(self, X)


480         return self
481
482     def predict(self, X):
483         """Predict multi-class targets using underlying estimators.
484
485         Parameters
486         ----------
487         X : {array-like, sparse matrix} of shape (n_samples, n_features)
488             Data.
489
490         Returns
491         -------
492         y : {array-like, sparse matrix} of shape (n_samples,) or (n_samples, n_classes)
493             Predicted multi-class targets.
494         """
495         check_is_fitted(self)
496
497         n_samples = _num_samples(X)
498         if self.label_binarizer_.y_type_ == "multiclass":
499             maxima = np.empty(n_samples, dtype=float)
500             maxima.fill(-np.inf)
501             argmaxima = np.zeros(n_samples, dtype=int)
502             n_classes = len(self.estimators_)
503             # Iterate in reverse order to match np.argmax tie-breaking behavior
504             for i, e in enumerate(reversed(self.estimators_)):
505                 pred = _predict_binary(e, X)
506                 np.maximum(maxima, pred, out=maxima)
507                 argmaxima[maxima == pred] = n_classes - i - 1
508             return self.classes_[argmaxima]
509         else:
510             thresh = _threshold_for_binary_predict(self.estimators_[0])
511             indices = array.array("i")
512             indptr = array.array("i", [0])
513             for e in self.estimators_:
514                 indices.extend(np.where(_predict_binary(e, X) > thresh)[0])
515                 indptr.append(len(indices))
516             data = np.ones(len(indices), dtype=int)
517             indicator = sp.csc_array(
518                 (data, indices, indptr), shape=(n_samples, len(self.estimators_))
519             )
520             return self.label_binarizer_.inverse_transform(indicator)
521
522     @available_if(_estimators_has("predict_proba"))
523     def predict_proba(self, X):
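The multiclass branch (lines 498–508) keeps only a running maximum score and the index that produced it, instead of stacking every estimator's output into one matrix. Because the estimators are visited in reverse, a sample whose top score is shared by several classes ends up with the lowest of their indices, which is the same choice np.argmax makes (it returns the first occurrence of the maximum). A minimal sketch of that loop, assuming the per-estimator scores are already available as rows of a hypothetical array `scores` that stands in for the _predict_binary calls:

```python
import numpy as np


def reverse_running_argmax(scores):
    """Per-sample argmax over the rows of `scores`, mirroring the loop in predict.

    `scores` is a hypothetical (n_classes, n_samples) array; row k plays the
    role of _predict_binary(estimators_[k], X).
    """
    n_classes, n_samples = scores.shape
    maxima = np.full(n_samples, -np.inf)
    argmaxima = np.zeros(n_samples, dtype=int)
    # Visit classes from last to first so that, on a tie, the smallest
    # class index is written last and therefore kept.
    for i, pred in enumerate(scores[::-1]):
        np.maximum(maxima, pred, out=maxima)
        # Samples where this class reaches (or ties) the running maximum.
        argmaxima[maxima == pred] = n_classes - i - 1
    return argmaxima


scores = np.array([
    [0.2, 0.9, 0.5],   # class 0
    [0.7, 0.9, 0.1],   # class 1
    [0.7, 0.3, 0.5],   # class 2
])
print(reverse_running_argmax(scores))  # [1 0 0]
print(np.argmax(scores, axis=0))       # [1 0 0]
```

Every sample in this toy input has a tie, and both methods pick the lower class index. In predict itself the resulting indices are then mapped to labels with self.classes_[argmaxima].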

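The multilabel branch (lines 509–520) writes the indicator matrix directly in compressed sparse column (CSC) form: column j belongs to estimator j, `indices` collects the sample (row) positions where that estimator's score exceeds the threshold, and `indptr[j]:indptr[j + 1]` marks where column j's entries sit inside `indices`. Every stored value is 1. The sketch below rebuilds the same (data, indices, indptr) triple from a hypothetical score array and decodes it with NumPy alone, so the layout is visible without SciPy. The threshold is passed in explicitly; in the listing it comes from _threshold_for_binary_predict(self.estimators_[0]), whose body is not shown here, so the 0.0 used below is an assumption suited to signed decision scores.

```python
import array

import numpy as np


def multilabel_csc(scores, thresh):
    """Build CSC components (data, indices, indptr) the way predict's else-branch does.

    `scores` is a hypothetical (n_labels, n_samples) array; row j stands in for
    _predict_binary(estimators_[j], X).
    """
    indices = array.array("i")
    indptr = array.array("i", [0])
    for pred in scores:
        # Sample (row) positions where this label is predicted positive.
        indices.extend(np.where(pred > thresh)[0])
        # This column's entries end at the current length of `indices`.
        indptr.append(len(indices))
    data = np.ones(len(indices), dtype=int)
    return data, np.asarray(indices), np.asarray(indptr)


def csc_to_dense(data, indices, indptr, shape):
    """Expand a CSC triple into a dense array of the given (n_rows, n_cols) shape."""
    dense = np.zeros(shape, dtype=data.dtype)
    for col in range(shape[1]):
        start, stop = indptr[col], indptr[col + 1]
        dense[indices[start:stop], col] = data[start:stop]
    return dense


scores = np.array([
    [0.8, -0.2, 1.5],   # label 0
    [-1.0, 0.3, 0.4],   # label 1
])
data, indices, indptr = multilabel_csc(scores, thresh=0.0)
print(indices.tolist(), indptr.tolist())  # [0, 2, 1, 2] [0, 2, 4]
print(csc_to_dense(data, indices, indptr, shape=(3, 2)).tolist())  # [[1, 0], [0, 1], [1, 1]]
```

The dense result equals (scores > thresh).T as integers; the CSC form stores only the positive entries, and predict hands that sparse indicator to label_binarizer_.inverse_transform.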
Callers 6

test_ovr_exceptions (function) · 0.95
test_ovr_partial_fit (function) · 0.95
test_ovr_always_present (function) · 0.95
test_ovr_pipeline (function) · 0.95
fit_single (function) · 0.95

Calls 5

check_is_fitted (function) · 0.90
_num_samples (function) · 0.90
_predict_binary (function) · 0.85
inverse_transform (method) · 0.45

Tested by 5

test_ovr_exceptions (function) · 0.76
test_ovr_partial_fit (function) · 0.76
test_ovr_always_present (function) · 0.76
test_ovr_pipeline (function) · 0.76