(name, estimator_orig)
| 2755 | |
| 2756 | @ignore_warnings(category=FutureWarning) |
| 2757 | def check_classifier_multioutput(name, estimator_orig): |
| 2758 | n_samples, n_labels, n_classes = 42, 5, 3 |
| 2759 | tags = get_tags(estimator_orig) |
| 2760 | estimator = clone(estimator_orig) |
| 2761 | X, y = make_multilabel_classification( |
| 2762 | random_state=42, n_samples=n_samples, n_labels=n_labels, n_classes=n_classes |
| 2763 | ) |
| 2764 | X = _enforce_estimator_tags_X(estimator, X) |
| 2765 | estimator.fit(X, y) |
| 2766 | y_pred = estimator.predict(X) |
| 2767 | |
| 2768 | assert y_pred.shape == (n_samples, n_classes), ( |
| 2769 | "The shape of the prediction for multioutput data is " |
| 2770 | "incorrect. Expected {}, got {}.".format((n_samples, n_labels), y_pred.shape) |
| 2771 | ) |
| 2772 | assert y_pred.dtype.kind == "i" |
| 2773 | |
| 2774 | if hasattr(estimator, "decision_function"): |
| 2775 | decision = estimator.decision_function(X) |
| 2776 | assert isinstance(decision, np.ndarray) |
| 2777 | assert decision.shape == (n_samples, n_classes), ( |
| 2778 | "The shape of the decision function output for " |
| 2779 | "multioutput data is incorrect. Expected {}, got {}.".format( |
| 2780 | (n_samples, n_classes), decision.shape |
| 2781 | ) |
| 2782 | ) |
| 2783 | |
| 2784 | dec_pred = (decision > 0).astype(int) |
| 2785 | dec_exp = estimator.classes_[dec_pred] |
| 2786 | assert_array_equal(dec_exp, y_pred) |
| 2787 | |
| 2788 | if hasattr(estimator, "predict_proba"): |
| 2789 | y_prob = estimator.predict_proba(X) |
| 2790 | |
| 2791 | if isinstance(y_prob, list) and not tags.classifier_tags.poor_score: |
| 2792 | for i in range(n_classes): |
| 2793 | assert y_prob[i].shape == (n_samples, 2), ( |
| 2794 | "The shape of the probability for multioutput data is" |
| 2795 | " incorrect. Expected {}, got {}.".format( |
| 2796 | (n_samples, 2), y_prob[i].shape |
| 2797 | ) |
| 2798 | ) |
| 2799 | assert_array_equal( |
| 2800 | np.argmax(y_prob[i], axis=1).astype(int), y_pred[:, i] |
| 2801 | ) |
| 2802 | elif not tags.classifier_tags.poor_score: |
| 2803 | assert y_prob.shape == (n_samples, n_classes), ( |
| 2804 | "The shape of the probability for multioutput data is" |
| 2805 | " incorrect. Expected {}, got {}.".format( |
| 2806 | (n_samples, n_classes), y_prob.shape |
| 2807 | ) |
| 2808 | ) |
| 2809 | assert_array_equal(y_prob.round().astype(int), y_pred) |
| 2810 | |
| 2811 | if hasattr(estimator, "decision_function") and hasattr(estimator, "predict_proba"): |
| 2812 | for i in range(n_classes): |
| 2813 | y_proba = estimator.predict_proba(X)[:, i] |
| 2814 | y_decision = estimator.decision_function(X) |
nothing calls this directly
no test coverage detected
searching dependent graphs…