MCPcopy Index your code
hub / github.com/scikit-learn/scikit-learn / _BaseChain

Class _BaseChain

sklearn/multioutput.py:638–842  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

636
637
638class _BaseChain(BaseEstimator, metaclass=ABCMeta):
# Allowed values for each constructor parameter, keyed by the parameter
# name used in ``__init__``. Key order and validator objects are the same
# as a plain dict literal would produce.
_parameter_constraints: dict = dict(
    # any object exposing both ``fit`` and ``predict``
    estimator=[HasMethods(["fit", "predict"])],
    # an explicit array-like, the string "random", or None
    order=["array-like", StrOptions({"random"}), None],
    # a cross-validation object/spec, or the string "prefit"
    cv=["cv_object", StrOptions({"prefit"})],
    random_state=["random_state"],
    verbose=["boolean"],
)
646
def __init__(
    self,
    estimator,
    *,
    order=None,
    cv=None,
    random_state=None,
    verbose=False,
):
    """Record the chain's constructor arguments as same-named attributes.

    No validation or transformation happens here; every argument is stored
    exactly as passed.

    Parameters
    ----------
    estimator : object
        Base estimator used for each link of the chain.
    order : array-like, "random" or None, default=None
        Order in which outputs are chained.
    cv : cross-validation spec, "prefit" or None, default=None
        How the chain's intermediate predictions are obtained.
    random_state : int, RandomState instance or None, default=None
        Seed / generator for randomized behavior.
    verbose : bool, default=False
        Whether progress messages are produced.
    """
    # Plain attribute storage only (scikit-learn estimator convention).
    self.verbose = verbose
    self.random_state = random_state
    self.cv = cv
    self.order = order
    self.estimator = estimator
661
def _log_message(self, *, estimator_idx, n_estimators, processing_msg):
    """Build the progress message for one estimator of the chain.

    Returns ``None`` when ``self.verbose`` is falsy. Otherwise it returns a
    string of the form ``"(<estimator_idx> of <n_estimators>) <processing_msg>"``.
    """
    return (
        f"({estimator_idx} of {n_estimators}) {processing_msg}"
        if self.verbose
        else None
    )
666
667 def _get_predictions(self, X, *, output_method):
668 """Get predictions for each model in the chain."""
669 check_is_fitted(self)
670 X = validate_data(self, X, accept_sparse=True, reset=False)
671 Y_output_chain = np.zeros((X.shape[0], len(self.estimators_)))
672 Y_feature_chain = np.zeros((X.shape[0], len(self.estimators_)))
673
674 # `RegressorChain` does not have a `chain_method_` parameter so we
675 # default to "predict"
676 chain_method = getattr(self, "chain_method_", "predict")
677 hstack = sp.hstack if sp.issparse(X) else np.hstack
678 for chain_idx, estimator in enumerate(self.estimators_):
679 previous_predictions = Y_feature_chain[:, :chain_idx]
680 # if `X` is a scipy sparse dok_array, we convert it to a sparse
681 # coo_array format before hstacking, it's faster; see
682 # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039:
683 if sp.issparse(X) and not sp.isspmatrix(X) and X.format == "dok":
684 X = sp.coo_array(X)
685 X_aug = hstack((X, previous_predictions))
686
687 feature_predictions, _ = _get_response_values(
688 estimator,
689 X_aug,
690 response_method=chain_method,
691 )
692 Y_feature_chain[:, chain_idx] = feature_predictions
693
694 output_predictions, _ = _get_response_values(
695 estimator,

Callers

nothing calls this directly

Calls 2

HasMethods (class) · 0.90
StrOptions (class) · 0.90

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…