(
name, classifier_orig, readonly_memmap=False, X_dtype="float64"
)
| 2979 | |
| 2980 | @ignore_warnings # Warnings are raised by decision function |
| 2981 | def check_classifiers_train( |
| 2982 | name, classifier_orig, readonly_memmap=False, X_dtype="float64" |
| 2983 | ): |
| 2984 | X_m, y_m = make_blobs(n_samples=300, random_state=0) |
| 2985 | X_m = X_m.astype(X_dtype) |
| 2986 | X_m, y_m = shuffle(X_m, y_m, random_state=7) |
| 2987 | X_m = StandardScaler().fit_transform(X_m) |
| 2988 | # generate binary problem from multi-class one |
| 2989 | y_b = y_m[y_m != 2] |
| 2990 | X_b = X_m[y_m != 2] |
| 2991 | |
| 2992 | if name in ["BernoulliNB", "MultinomialNB", "ComplementNB", "CategoricalNB"]: |
| 2993 | X_m -= X_m.min() |
| 2994 | X_b -= X_b.min() |
| 2995 | |
| 2996 | if readonly_memmap: |
| 2997 | X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b]) |
| 2998 | |
| 2999 | problems = [(X_b, y_b)] |
| 3000 | tags = get_tags(classifier_orig) |
| 3001 | if tags.classifier_tags.multi_class: |
| 3002 | problems.append((X_m, y_m)) |
| 3003 | |
| 3004 | for X, y in problems: |
| 3005 | classes = np.unique(y) |
| 3006 | n_classes = len(classes) |
| 3007 | n_samples, n_features = X.shape |
| 3008 | classifier = clone(classifier_orig) |
| 3009 | X = _enforce_estimator_tags_X(classifier, X) |
| 3010 | y = _enforce_estimator_tags_y(classifier, y) |
| 3011 | |
| 3012 | set_random_state(classifier) |
| 3013 | # raises error on malformed input for fit |
| 3014 | if not tags.no_validation: |
| 3015 | with raises( |
| 3016 | ValueError, |
| 3017 | err_msg=( |
| 3018 | f"The classifier {name} does not raise an error when " |
| 3019 | "incorrect/malformed input data for fit is passed. The number " |
| 3020 | "of training examples is not the same as the number of " |
| 3021 | "labels. Perhaps use check_X_y in fit." |
| 3022 | ), |
| 3023 | ): |
| 3024 | classifier.fit(X, y[:-1]) |
| 3025 | |
| 3026 | # fit |
| 3027 | classifier.fit(X, y) |
| 3028 | # with lists |
| 3029 | classifier.fit(X.tolist(), y.tolist()) |
| 3030 | assert hasattr(classifier, "classes_") |
| 3031 | y_pred = classifier.predict(X) |
| 3032 | |
| 3033 | assert y_pred.shape == (n_samples,) |
| 3034 | # training set performance |
| 3035 | if not tags.classifier_tags.poor_score: |
| 3036 | assert accuracy_score(y, y_pred) > 0.83 |
| 3037 | |
| 3038 | # raises error on malformed input for predict |
nothing calls this directly
no test coverage detected
searching dependent graphs…