MCP · copy
hub / github.com/scikit-learn/scikit-learn / check_classifier_multioutput

Function check_classifier_multioutput

sklearn/utils/estimator_checks.py:2757–2815  ·  view source on GitHub ↗
(name, estimator_orig)

Source from the content-addressed store, hash-verified

2755
2756@ignore_warnings(category=FutureWarning)
2757def check_classifier_multioutput(name, estimator_orig):
2758 n_samples, n_labels, n_classes = 42, 5, 3
2759 tags = get_tags(estimator_orig)
2760 estimator = clone(estimator_orig)
2761 X, y = make_multilabel_classification(
2762 random_state=42, n_samples=n_samples, n_labels=n_labels, n_classes=n_classes
2763 )
2764 X = _enforce_estimator_tags_X(estimator, X)
2765 estimator.fit(X, y)
2766 y_pred = estimator.predict(X)
2767
2768 assert y_pred.shape == (n_samples, n_classes), (
2769 "The shape of the prediction for multioutput data is "
2770 "incorrect. Expected {}, got {}.".format((n_samples, n_labels), y_pred.shape)
2771 )
2772 assert y_pred.dtype.kind == "i"
2773
2774 if hasattr(estimator, "decision_function"):
2775 decision = estimator.decision_function(X)
2776 assert isinstance(decision, np.ndarray)
2777 assert decision.shape == (n_samples, n_classes), (
2778 "The shape of the decision function output for "
2779 "multioutput data is incorrect. Expected {}, got {}.".format(
2780 (n_samples, n_classes), decision.shape
2781 )
2782 )
2783
2784 dec_pred = (decision > 0).astype(int)
2785 dec_exp = estimator.classes_[dec_pred]
2786 assert_array_equal(dec_exp, y_pred)
2787
2788 if hasattr(estimator, "predict_proba"):
2789 y_prob = estimator.predict_proba(X)
2790
2791 if isinstance(y_prob, list) and not tags.classifier_tags.poor_score:
2792 for i in range(n_classes):
2793 assert y_prob[i].shape == (n_samples, 2), (
2794 "The shape of the probability for multioutput data is"
2795 " incorrect. Expected {}, got {}.".format(
2796 (n_samples, 2), y_prob[i].shape
2797 )
2798 )
2799 assert_array_equal(
2800 np.argmax(y_prob[i], axis=1).astype(int), y_pred[:, i]
2801 )
2802 elif not tags.classifier_tags.poor_score:
2803 assert y_prob.shape == (n_samples, n_classes), (
2804 "The shape of the probability for multioutput data is"
2805 " incorrect. Expected {}, got {}.".format(
2806 (n_samples, n_classes), y_prob.shape
2807 )
2808 )
2809 assert_array_equal(y_prob.round().astype(int), y_pred)
2810
2811 if hasattr(estimator, "decision_function") and hasattr(estimator, "predict_proba"):
2812 for i in range(n_classes):
2813 y_proba = estimator.predict_proba(X)[:, i]
2814 y_decision = estimator.decision_function(X)

Callers

nothing calls this directly

Calls 9

get_tags — Function · 0.90
clone — Function · 0.90
format — Method · 0.80
fit — Method · 0.45
predict — Method · 0.45
decision_function — Method · 0.45
predict_proba — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…