Called after collection is completed. Parameters ---------- config : pytest config; items : list of collected items
(config, items)
| 129 | |
| 130 | |
| 131 | def pytest_collection_modifyitems(config, items): |
| 132 | """Called after collect is completed. |
| 133 | |
| 134 | Parameters |
| 135 | ---------- |
| 136 | config : pytest config |
| 137 | items : list of collected items |
| 138 | """ |
| 139 | run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0" |
| 140 | skip_network = pytest.mark.skip( |
| 141 | reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0" |
| 142 | ) |
| 143 | |
| 144 | # download datasets during collection to avoid thread unsafe behavior |
| 145 | # when running pytest in parallel with pytest-xdist |
| 146 | dataset_features_set = set(dataset_fetchers) |
| 147 | datasets_to_download = set() |
| 148 | |
| 149 | for item in items: |
| 150 | if isinstance(item, DoctestItem) and "fetch_" in item.name: |
| 151 | fetcher_function_name = item.name.split(".")[-1] |
| 152 | dataset_fetchers_key = f"{fetcher_function_name}_fxt" |
| 153 | dataset_to_fetch = set([dataset_fetchers_key]) & dataset_features_set |
| 154 | elif not hasattr(item, "fixturenames"): |
| 155 | continue |
| 156 | else: |
| 157 | item_fixtures = set(item.fixturenames) |
| 158 | dataset_to_fetch = item_fixtures & dataset_features_set |
| 159 | |
| 160 | if not dataset_to_fetch: |
| 161 | continue |
| 162 | |
| 163 | if run_network_tests: |
| 164 | datasets_to_download |= dataset_to_fetch |
| 165 | else: |
| 166 | # network tests are skipped |
| 167 | item.add_marker(skip_network) |
| 168 | |
| 169 | # Only download datasets on the first worker spawned by pytest-xdist |
| 170 | # to avoid thread unsafe behavior. If pytest-xdist is not used, we still |
| 171 | # download before tests run. |
| 172 | worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0") |
| 173 | if worker_id == "gw0" and run_network_tests: |
| 174 | for name in datasets_to_download: |
| 175 | with suppress(SkipTest): |
| 176 | dataset_fetchers[name]() |
| 177 | |
| 178 | for item in items: |
| 179 | # Known failure on with GradientBoostingClassifier on ARM64 |
| 180 | if ( |
| 181 | item.name.endswith("GradientBoostingClassifier") |
| 182 | and platform.machine() == "aarch64" |
| 183 | ): |
| 184 | marker = pytest.mark.xfail( |
| 185 | reason=( |
| 186 | "know failure. See " |
| 187 | "https://github.com/scikit-learn/scikit-learn/issues/17797" |
| 188 | ) |