Repository: github.com/scikit-learn/scikit-learn

Function check_classifiers_train

sklearn/utils/estimator_checks.py:2981–3116

check_classifiers_train(name, classifier_orig, readonly_memmap=False, X_dtype="float64")
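A minimal sketch of running this check by hand on one estimator, assuming a scikit-learn version whose `sklearn.utils.estimator_checks` exposes `check_classifiers_train` with the signature above. The individual `check_*` functions are importable, but they are not a documented, stable entry point; `check_estimator` and `parametrize_with_checks` are the supported ways to run the whole suite. The check passes silently and raises on failure. The `functools.partial` wrapper is only a convenience for binding the `readonly_memmap` and `X_dtype` options.

```python
from functools import partial

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_classifiers_train

clf = LogisticRegression()

# Default call: float64 data held in ordinary in-memory arrays.
check_classifiers_train("LogisticRegression", clf)

# Same check on read-only, memory-mapped float32 data.
check_float32_memmap = partial(
    check_classifiers_train, readonly_memmap=True, X_dtype="float32"
)
check_float32_memmap("LogisticRegression", clf)

# The check fits clones only, so the estimator passed in stays unfitted.
print(hasattr(clf, "classes_"))
```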


```python
@ignore_warnings  # Warnings are raised by decision function
def check_classifiers_train(
    name, classifier_orig, readonly_memmap=False, X_dtype="float64"
):
    X_m, y_m = make_blobs(n_samples=300, random_state=0)
    X_m = X_m.astype(X_dtype)
    X_m, y_m = shuffle(X_m, y_m, random_state=7)
    X_m = StandardScaler().fit_transform(X_m)
    # generate binary problem from multi-class one
    y_b = y_m[y_m != 2]
    X_b = X_m[y_m != 2]

    if name in ["BernoulliNB", "MultinomialNB", "ComplementNB", "CategoricalNB"]:
        X_m -= X_m.min()
        X_b -= X_b.min()

    if readonly_memmap:
        X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b])

    problems = [(X_b, y_b)]
    tags = get_tags(classifier_orig)
    if tags.classifier_tags.multi_class:
        problems.append((X_m, y_m))

    for X, y in problems:
        classes = np.unique(y)
        n_classes = len(classes)
        n_samples, n_features = X.shape
        classifier = clone(classifier_orig)
        X = _enforce_estimator_tags_X(classifier, X)
        y = _enforce_estimator_tags_y(classifier, y)

        set_random_state(classifier)
        # raises error on malformed input for fit
        if not tags.no_validation:
            with raises(
                ValueError,
                err_msg=(
                    f"The classifier {name} does not raise an error when "
                    "incorrect/malformed input data for fit is passed. The number "
                    "of training examples is not the same as the number of "
                    "labels. Perhaps use check_X_y in fit."
                ),
            ):
                classifier.fit(X, y[:-1])

        # fit
        classifier.fit(X, y)
        # with lists
        classifier.fit(X.tolist(), y.tolist())
        assert hasattr(classifier, "classes_")
        y_pred = classifier.predict(X)

        assert y_pred.shape == (n_samples,)
        # training set performance
        if not tags.classifier_tags.poor_score:
            assert accuracy_score(y, y_pred) > 0.83

        # raises error on malformed input for predict
        # ...
```
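The visible part of the loop amounts to a small contract: `fit` must reject an `X`/`y` length mismatch with `ValueError`, accept plain Python lists, set `classes_`, and `predict` must return one label per sample. The sketch below reproduces the check's data recipe and those assertions by hand against a classifier written only for this illustration. `NearestMeanClassifier` is hypothetical and is not part of scikit-learn. The sketch skips the naive Bayes shift, the memmap option, the tag handling, and the predict-side checks in the truncated remainder, and it reports accuracy instead of asserting the 0.83 threshold.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import make_blobs
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class NearestMeanClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical toy classifier: predicts the class whose mean is closest."""

    def fit(self, X, y):
        # check_X_y converts lists to arrays and raises ValueError when
        # X and y have different numbers of samples.
        X, y = check_X_y(X, y)
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        self.means_ = np.array(
            [X[y_idx == k].mean(axis=0) for k in range(len(self.classes_))]
        )
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # Squared Euclidean distance from every sample to every class mean.
        dist = ((X[:, None, :] - self.means_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[dist.argmin(axis=1)]


# Data recipe from the check: 3 blobs, shuffled, standardized,
# plus a binary problem obtained by dropping every class-2 sample.
X_m, y_m = make_blobs(n_samples=300, random_state=0)
X_m, y_m = shuffle(X_m, y_m, random_state=7)
X_m = StandardScaler().fit_transform(X_m)
X_b, y_b = X_m[y_m != 2], y_m[y_m != 2]

results = {}
for label, (X, y) in {"binary": (X_b, y_b), "multiclass": (X_m, y_m)}.items():
    clf = NearestMeanClassifier()
    try:
        clf.fit(X, y[:-1])  # one label short: must be rejected
        raised = False
    except ValueError:
        raised = True
    clf.fit(X.tolist(), y.tolist())  # plain Python lists must also work
    y_pred = clf.predict(X)
    shape_ok = y_pred.shape == (X.shape[0],)
    results[label] = (raised, shape_ok, accuracy_score(y, y_pred))

for label, (raised, shape_ok, acc) in results.items():
    print(f"{label}: raised={raised} shape_ok={shape_ok} accuracy={acc:.3f}")
```

Routing all input through `check_X_y` and `check_array` is what the check's error message ("Perhaps use check_X_y in fit") points toward. It gives both the length validation and the list-to-array conversion without custom code.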

Callers

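No direct call sites detected. The check is normally collected and run by check_estimator or parametrize_with_checks rather than called by name.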

Calls (15 detected; 9 listed)

make_blobs (function)
shuffle (function)
StandardScaler (class)
get_tags (function)
clone (function)
set_random_state (function)
raises (function)
accuracy_score (function)
assert_allclose (function)

Tested by

no test coverage detected
