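Return (xp, device) for array API testing. Parameters: array_namespace (str) – the importable name of the array namespace module; device_name (str or None, default=None) – the device name for array allocation, or None for the default device. Returns: xp (module) – the module object for the requested array namespace; device (object, str or None) – the library-specific device object that can be passed to xp.asarray(..., device=device), which might be a string rather than a library-specific device object.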
(array_namespace, device_name=None, dtype_name=None)
| 1306 | |
| 1307 | |
| 1308 | def _array_api_for_tests(array_namespace, device_name=None, dtype_name=None): |
| 1309 | """Return (xp, device) for array API testing. |
| 1310 | |
| 1311 | Parameters |
| 1312 | ---------- |
| 1313 | array_namespace : str |
| 1314 | The importable name of the array namespace module. |
| 1315 | device_name : str or None, default=None |
| 1316 | The device name for array allocation. Can be None for default device. |
| 1317 | |
| 1318 | Returns |
| 1319 | ------- |
| 1320 | xp : module |
| 1321 | The module object for the requested array namespace. |
| 1322 | device : object, str or None |
| 1323 | The library specific device object that can be passed to |
| 1324 | xp.asarray(..., device=device). This might be a string and not |
| 1325 | a library specific device object. |
| 1326 | """ |
| 1327 | try: |
| 1328 | array_mod = importlib.import_module(array_namespace) |
| 1329 | except (ModuleNotFoundError, ImportError): |
| 1330 | raise SkipTest( |
| 1331 | f"{array_namespace} is not installed: not checking array_api input" |
| 1332 | ) |
| 1333 | |
| 1334 | if os.environ.get("SCIPY_ARRAY_API") is None: |
| 1335 | raise SkipTest("SCIPY_ARRAY_API is not set: not checking array_api input") |
| 1336 | |
| 1337 | from sklearn.externals.array_api_compat import get_namespace |
| 1338 | |
| 1339 | # First create an array using the chosen array module and then get the |
| 1340 | # corresponding (compatibility wrapped) array namespace based on it. |
| 1341 | # This is because `cupy` is not the same as the compatibility wrapped |
| 1342 | # namespace of a CuPy array. |
| 1343 | device = None |
| 1344 | xp = get_namespace(array_mod.asarray(1)) |
| 1345 | if ( |
| 1346 | array_namespace == "torch" |
| 1347 | and device_name == "cuda" |
| 1348 | and not xp.backends.cuda.is_built() |
| 1349 | ): |
| 1350 | raise SkipTest("PyTorch test requires cuda, which is not available") |
| 1351 | elif array_namespace == "dpnp": # pragma: nocover |
| 1352 | dpctl = pytest.importorskip("dpctl") |
| 1353 | if device_name is None: |
| 1354 | available_devices = dpctl.get_devices() |
| 1355 | if not available_devices: |
| 1356 | raise SkipTest("Skipping dpnp test because no SYCL devices found") |
| 1357 | else: |
| 1358 | device = available_devices[0] |
| 1359 | elif not dpctl.get_devices(device_type=device_name): |
| 1360 | raise SkipTest(f"Skipping dpnp test because no {device_name} device found") |
| 1361 | |
| 1362 | elif array_namespace == "torch" and device_name == "mps": |
| 1363 | if os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1": |
| 1364 | # For now we need PYTORCH_ENABLE_MPS_FALLBACK=1 for all estimators to work |
| 1365 | # when using the MPS device. |
no test coverage detected
searching dependent graphs…