Repository: github.com/scikit-learn/scikit-learn

Function _array_api_for_tests

sklearn/utils/_testing.py:1308–1409

Return (xp, device) for array API testing.

(array_namespace, device_name=None, dtype_name=None)
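The first two checks in _array_api_for_tests follow a common skip-early pattern. The function imports the requested module by name. It then turns either a missing module or a missing SCIPY_ARRAY_API opt-in into a skipped test rather than a failure. The minimal, stdlib-only sketch below reproduces just that pattern. The helper name import_namespace_or_skip is hypothetical and is not part of scikit-learn.

```python
import importlib
import os
from types import ModuleType
from unittest import SkipTest


def import_namespace_or_skip(name: str, env_var: str = "SCIPY_ARRAY_API") -> ModuleType:
    """Hypothetical helper: import `name` or skip, then require `env_var` to be set."""
    try:
        # ModuleNotFoundError is a subclass of ImportError, so this covers both.
        module = importlib.import_module(name)
    except ImportError:
        raise SkipTest(f"{name} is not installed: not checking array_api input")

    # Mirror the opt-in gate: the variable only has to exist; any value passes.
    if os.environ.get(env_var) is None:
        raise SkipTest(f"{env_var} is not set: not checking array_api input")
    return module
```

The import is attempted before the environment check, as in the original. A missing library therefore produces the "not installed" message even when the variable is also unset. pytest reports a raised unittest.SkipTest as a skipped test, so callers need no extra handling.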


```python
# Imports needed to run this excerpt standalone.
import importlib
import os
from unittest import SkipTest

import pytest


def _array_api_for_tests(array_namespace, device_name=None, dtype_name=None):
    """Return (xp, device) for array API testing.

    Parameters
    ----------
    array_namespace : str
        The importable name of the array namespace module.
    device_name : str or None, default=None
        The device name for array allocation. Can be None for default device.

    Returns
    -------
    xp : module
        The module object for the requested array namespace.
    device : object, str or None
        The library specific device object that can be passed to
        xp.asarray(..., device=device). This might be a string and not
        a library specific device object.
    """
    try:
        array_mod = importlib.import_module(array_namespace)
    except (ModuleNotFoundError, ImportError):
        raise SkipTest(
            f"{array_namespace} is not installed: not checking array_api input"
        )

    if os.environ.get("SCIPY_ARRAY_API") is None:
        raise SkipTest("SCIPY_ARRAY_API is not set: not checking array_api input")

    from sklearn.externals.array_api_compat import get_namespace

    # First create an array using the chosen array module and then get the
    # corresponding (compatibility wrapped) array namespace based on it.
    # This is because `cupy` is not the same as the compatibility wrapped
    # namespace of a CuPy array.
    device = None
    xp = get_namespace(array_mod.asarray(1))
    if (
        array_namespace == "torch"
        and device_name == "cuda"
        and not xp.backends.cuda.is_built()
    ):
        raise SkipTest("PyTorch test requires cuda, which is not available")
    elif array_namespace == "dpnp":  # pragma: nocover
        dpctl = pytest.importorskip("dpctl")
        if device_name is None:
            available_devices = dpctl.get_devices()
            if not available_devices:
                raise SkipTest("Skipping dpnp test because no SYCL devices found")
            else:
                device = available_devices[0]
        elif not dpctl.get_devices(device_type=device_name):
            raise SkipTest(f"Skipping dpnp test because no {device_name} device found")

    elif array_namespace == "torch" and device_name == "mps":
        if os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
            # For now we need PYTORCH_ENABLE_MPS_FALLBACK=1 for all estimators to work
            # when using the MPS device.
            ...  # listing truncated here; the function continues through line 1409
```
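The device checks in the listing can be read as a small decision table. The sketch below models them as a pure function, so the logic can be exercised without GPUs or SYCL hardware. It takes availability facts as arguments instead of querying torch or dpctl. The function name device_skip_reason and its keyword parameters are hypothetical, and the device labels used with it are arbitrary strings. The sketch does not model the pytest.importorskip("dpctl") step. It also omits the MPS branch and everything after it, because the listing is cut off there.

```python
from typing import Mapping, Optional, Sequence


def device_skip_reason(
    array_namespace: str,
    device_name: Optional[str],
    *,
    torch_cuda_built: bool = False,
    sycl_devices_by_type: Optional[Mapping[str, Sequence[str]]] = None,
) -> Optional[str]:
    """Hypothetical model of the shown checks: return a skip reason, or None."""
    sycl_devices_by_type = sycl_devices_by_type or {}

    # torch + cuda: skip when PyTorch was not built with CUDA support.
    if array_namespace == "torch" and device_name == "cuda" and not torch_cuda_built:
        return "PyTorch test requires cuda, which is not available"

    if array_namespace == "dpnp":
        if device_name is None:
            # Any SYCL device will do; the real code then picks the first one.
            if not any(sycl_devices_by_type.values()):
                return "Skipping dpnp test because no SYCL devices found"
        elif not sycl_devices_by_type.get(device_name):
            # A specific device type was requested but none of that type exists.
            return f"Skipping dpnp test because no {device_name} device found"

    # MPS and any later branches are not modelled (the listing is truncated).
    return None
```

Passing availability as data keeps the branch logic testable on any machine. torch.backends.cuda.is_built() reports only whether PyTorch was compiled with CUDA support. It does not by itself guarantee that a working GPU and driver are present.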
