github.com/scikit-learn/scikit-learn

Function move_to

sklearn/utils/_array_api.py:519–628

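Move all arrays to `xp` and `device`. Each array is moved to the target namespace and device unless it already uses them, in which case it is left unchanged. `None` entries in `arrays` are left unchanged. Sparse arrays are accepted (as pass-through) if the target namespace is NumPy, in which case they are returned unchanged; otherwise a `TypeError` is raised. When the target namespace and device support at most float32 precision, float64 inputs are first cast to float32.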
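`move_to` passes `None` entries through unchanged. It allows sparse inputs only when the target namespace is NumPy. The plain-Python sketch below illustrates that rule. It is a simplified, hypothetical illustration, not scikit-learn's code. `check_pass_through`, its `is_sparse` predicate, its `target_is_numpy` flag and the `FakeSparse` marker class are all assumptions. They stand in for the source's `sp.issparse(array)` and `_is_numpy_namespace(xp)`.

```python
from typing import Any, Callable, List, Sequence, Tuple


def check_pass_through(
    arrays: Sequence[Any],
    *,
    target_is_numpy: bool,
    is_sparse: Callable[[Any], bool],
) -> Tuple[List[bool], List[bool]]:
    """Hypothetical helper: flag inputs that are passed through unchanged.

    `is_sparse` stands in for scipy.sparse.issparse and `target_is_numpy`
    for a "target namespace is NumPy" check (both assumptions).
    """
    sparse_mask = [is_sparse(a) for a in arrays]
    none_mask = [a is None for a in arrays]
    # Sparse inputs may only pass through when the target is NumPy.
    if any(sparse_mask) and not target_is_numpy:
        raise TypeError(
            "Sparse arrays are only accepted (and passed through) when the "
            "target namespace is NumPy"
        )
    return sparse_mask, none_mask


class FakeSparse:
    """Hypothetical marker type used only to exercise the sketch."""


def fake_is_sparse(a: Any) -> bool:
    return isinstance(a, FakeSparse)


sparse_mask, none_mask = check_pass_through(
    [[1.0, 2.0], None, FakeSparse()],
    target_is_numpy=True,
    is_sparse=fake_is_sparse,
)
print(sparse_mask)  # [False, False, True]
print(none_mask)    # [False, True, False]
```

Computing both masks up front lets a later loop skip `None` and sparse entries without re-testing them. The source builds `sparse_mask` and `none_mask` for the same purpose.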

(*arrays, xp, device)
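In this signature, `xp` and `device` follow `*arrays`, so they must be passed by keyword. Per the docstring, a single input comes back as a bare array, and several inputs come back as a tuple. The sketch below mirrors only that calling convention. `move_to_sketch` is a hypothetical stand-in that performs no conversion, and the string values for `xp` and `device` are placeholders, not real namespace or device objects.

```python
from typing import Any, Tuple, Union


def move_to_sketch(*arrays: Any, xp: Any, device: Any) -> Union[Any, Tuple[Any, ...]]:
    """Hypothetical stand-in mirroring only move_to's calling convention.

    The namespace/device conversion is omitted; inputs are returned as-is.
    """
    converted = tuple(arrays)
    # One input -> the bare array; several inputs -> a tuple.
    return converted[0] if len(converted) == 1 else converted


a, b = [1.0], [2.0]
single = move_to_sketch(a, xp="numpy", device="cpu")
pair = move_to_sketch(a, b, xp="numpy", device="cpu")
print(single)  # [1.0]
print(pair)    # ([1.0], [2.0])

# Passing xp/device positionally folds them into *arrays and leaves the
# keyword-only parameters missing, so Python raises TypeError.
try:
    move_to_sketch(a, "numpy", "cpu")
except TypeError as exc:
    print("TypeError:", exc)
```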

def move_to(*arrays, xp, device):
    """Move all arrays to `xp` and `device`.

    Each array will be moved to the reference namespace and device if
    it is not already using it. Otherwise the array is left unchanged.

    `arrays` may contain `None` entries, these are left unchanged.

    Sparse arrays are accepted (as pass through) if the reference namespace is
    NumPy, in which case they are returned unchanged. Otherwise a `TypeError`
    is raised.

    Parameters
    ----------
    *arrays : iterable of arrays
        Arrays to (potentially) move.

    xp : namespace
        Array API namespace to move arrays to.

    device : device
        Array API device to move arrays to.

    Returns
    -------
    arrays : tuple or array
        Tuple of arrays with the same namespace and device as reference. Single array
        returned if only one `arrays` input.
    """
    if isinstance(device, str) and device == "xpu":  # pragma: nocover
        # XXX: Workaround for PyTorch XPU bug for `from_dlpack` calls with
        # device strings that do not include any device number suffix.
        # https://github.com/pytorch/pytorch/issues/181140
        device += ":0"

    sparse_mask = [sp.issparse(array) for array in arrays]
    none_mask = [array is None for array in arrays]
    if any(sparse_mask) and not _is_numpy_namespace(xp):
        raise TypeError(
            "Sparse arrays are only accepted (and passed through) when the target "
            "namespace is Numpy"
        )

    arrays_ = arrays
    # Down cast float64 `arrays` when highest precision of `xp`/`device` is float32
    if _max_precision_float_dtype(xp, device) == xp.float32:
        arrays_ = []
        for array in arrays:
            xp_array, _ = get_namespace(array)
            if getattr(array, "dtype", None) == xp_array.float64:
                arrays_.append(xp_array.astype(array, xp_array.float32))
            else:
                arrays_.append(array)

    converted_arrays = []
    for array, is_sparse, is_none in zip(arrays_, sparse_mask, none_mask):
        if is_none:
            converted_arrays.append(None)
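
The XPU workaround at the top of `move_to` rewrites a bare `"xpu"` device string to `"xpu:0"`. This sidesteps the PyTorch `from_dlpack` issue linked in the source. It leaves indexed strings and non-string device objects alone. The sketch below isolates that normalization step. The helper name `normalize_device` is hypothetical.

```python
from typing import Any


def normalize_device(device: Any) -> Any:
    """Hypothetical helper: give a bare "xpu" device string an explicit index.

    Mirrors the workaround in move_to for
    https://github.com/pytorch/pytorch/issues/181140.
    """
    if isinstance(device, str) and device == "xpu":
        return device + ":0"
    # Indexed strings ("xpu:1"), other device strings and non-string
    # device objects are returned unchanged.
    return device


print(normalize_device("xpu"))    # xpu:0
print(normalize_device("xpu:1"))  # xpu:1
print(normalize_device("cpu"))    # cpu
```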
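
The float64 downcast in `move_to` applies when the target namespace and device top out at float32. In that case every float64 input is cast to float32 before the transfer, and all other inputs (including `None`) are left alone. The NumPy-only sketch below illustrates the rule. `downcast_if_needed` and its `max_float_dtype` argument are assumptions. The argument stands in for the result of the private `_max_precision_float_dtype(xp, device)`. The source casts each array inside its own namespace (via `get_namespace(array)`) rather than always using NumPy.

```python
import numpy as np


def downcast_if_needed(arrays, max_float_dtype):
    """Hypothetical helper: cast float64 arrays to float32 when needed.

    `max_float_dtype` stands in for scikit-learn's private
    `_max_precision_float_dtype(xp, device)` (an assumption of this sketch).
    """
    if max_float_dtype != np.float32:
        # Target supports float64: leave everything untouched.
        return list(arrays)
    out = []
    for array in arrays:
        # None, integer arrays and float32 arrays pass through as-is.
        if getattr(array, "dtype", None) == np.float64:
            out.append(array.astype(np.float32))
        else:
            out.append(array)
    return out


x64 = np.array([1.5, 2.5])  # float64
x_int = np.array([1, 2])    # integer dtype
result = downcast_if_needed([x64, x_int, None], np.float32)
print([getattr(r, "dtype", None) for r in result])
```

The `getattr(array, "dtype", None)` guard, also used in the source, lets the check run on `None` entries without a special case.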

Callers 15

fit (method) · 0.90
predict (method) · 0.90
predict (method) · 0.90
cross_val_predict (function) · 0.90
_fit_and_predict (function) · 0.90
_make_test_folds (method) · 0.90
_iter_indices (method) · 0.90
train_test_split (function) · 0.90
sample (method) · 0.90

Calls 6

any (function) · 0.85
_is_numpy_namespace (function) · 0.85
get_namespace (function) · 0.85
get_namespace_and_device (function) · 0.85
_convert_to_numpy (function) · 0.85
