Validation helper to check if the train/test sizes are meaningful w.r.t. the size of the data (n_samples).
Signature: `_validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None)`
| 2443 | |
| 2444 | |
| 2445 | def _validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None): |
| 2446 | """ |
| 2447 | Validation helper to check if the train/test sizes are meaningful w.r.t. the |
| 2448 | size of the data (n_samples). |
| 2449 | """ |
| 2450 | if test_size is None and train_size is None: |
| 2451 | test_size = default_test_size |
| 2452 | |
| 2453 | test_size_type = np.asarray(test_size).dtype.kind |
| 2454 | train_size_type = np.asarray(train_size).dtype.kind |
| 2455 | |
| 2456 | if (test_size_type == "i" and (test_size >= n_samples or test_size <= 0)) or ( |
| 2457 | test_size_type == "f" and (test_size <= 0 or test_size >= 1) |
| 2458 | ): |
| 2459 | raise ValueError( |
| 2460 | "test_size={0} should be either positive and smaller" |
| 2461 | " than the number of samples {1} or a float in the " |
| 2462 | "(0, 1) range".format(test_size, n_samples) |
| 2463 | ) |
| 2464 | |
| 2465 | if (train_size_type == "i" and (train_size >= n_samples or train_size <= 0)) or ( |
| 2466 | train_size_type == "f" and (train_size <= 0 or train_size >= 1) |
| 2467 | ): |
| 2468 | raise ValueError( |
| 2469 | "train_size={0} should be either positive and smaller" |
| 2470 | " than the number of samples {1} or a float in the " |
| 2471 | "(0, 1) range".format(train_size, n_samples) |
| 2472 | ) |
| 2473 | |
| 2474 | if train_size is not None and train_size_type not in ("i", "f"): |
| 2475 | raise ValueError("Invalid value for train_size: {}".format(train_size)) |
| 2476 | if test_size is not None and test_size_type not in ("i", "f"): |
| 2477 | raise ValueError("Invalid value for test_size: {}".format(test_size)) |
| 2478 | |
| 2479 | if train_size_type == "f" and test_size_type == "f" and train_size + test_size > 1: |
| 2480 | raise ValueError( |
| 2481 | "The sum of test_size and train_size = {}, should be in the (0, 1)" |
| 2482 | " range. Reduce test_size and/or train_size.".format(train_size + test_size) |
| 2483 | ) |
| 2484 | |
| 2485 | if test_size_type == "f": |
| 2486 | n_test = ceil(test_size * n_samples) |
| 2487 | elif test_size_type == "i": |
| 2488 | n_test = float(test_size) |
| 2489 | |
| 2490 | if train_size_type == "f": |
| 2491 | n_train = floor(train_size * n_samples) |
| 2492 | elif train_size_type == "i": |
| 2493 | n_train = float(train_size) |
| 2494 | |
| 2495 | if train_size is None: |
| 2496 | n_train = n_samples - n_test |
| 2497 | elif test_size is None: |
| 2498 | n_test = n_samples - n_train |
| 2499 | |
| 2500 | if n_train + n_test > n_samples: |
| 2501 | raise ValueError( |
| 2502 | "The sum of train_size and test_size = %d, " |
searching dependent graphs…