Move all arrays to `xp` and `device`. Each array is moved to the reference namespace and device if it is not already using them; otherwise the array is left unchanged. `arrays` may contain `None` entries; these are left unchanged. Sparse arrays are accepted (as pass-through) only if the reference namespace is NumPy, in which case they are returned unchanged; otherwise a `TypeError` is raised.
(*arrays, xp, device)
| 517 | |
| 518 | |
| 519 | def move_to(*arrays, xp, device): |
| 520 | """Move all arrays to `xp` and `device`. |
| 521 | |
| 522 | Each array will be moved to the reference namespace and device if |
| 523 | it is not already using it. Otherwise the array is left unchanged. |
| 524 | |
| 525 | `arrays` may contain `None` entries, these are left unchanged. |
| 526 | |
| 527 | Sparse arrays are accepted (as pass through) if the reference namespace is |
| 528 | NumPy, in which case they are returned unchanged. Otherwise a `TypeError` |
| 529 | is raised. |
| 530 | |
| 531 | Parameters |
| 532 | ---------- |
| 533 | *arrays : iterable of arrays |
| 534 | Arrays to (potentially) move. |
| 535 | |
| 536 | xp : namespace |
| 537 | Array API namespace to move arrays to. |
| 538 | |
| 539 | device : device |
| 540 | Array API device to move arrays to. |
| 541 | |
| 542 | Returns |
| 543 | ------- |
| 544 | arrays : tuple or array |
| 545 | Tuple of arrays with the same namespace and device as reference. Single array |
| 546 | returned if only one `arrays` input. |
| 547 | """ |
| 548 | if isinstance(device, str) and device == "xpu": # pragma: nocover |
| 549 | # XXX: Workaround for PyTorch XPU bug for `from_dlpack` calls with |
| 550 | # device strings that do not include any device number suffix. |
| 551 | # https://github.com/pytorch/pytorch/issues/181140 |
| 552 | device += ":0" |
| 553 | |
| 554 | sparse_mask = [sp.issparse(array) for array in arrays] |
| 555 | none_mask = [array is None for array in arrays] |
| 556 | if any(sparse_mask) and not _is_numpy_namespace(xp): |
| 557 | raise TypeError( |
| 558 | "Sparse arrays are only accepted (and passed through) when the target " |
| 559 | "namespace is Numpy" |
| 560 | ) |
| 561 | |
| 562 | arrays_ = arrays |
| 563 | # Down cast float64 `arrays` when highest precision of `xp`/`device` is float32 |
| 564 | if _max_precision_float_dtype(xp, device) == xp.float32: |
| 565 | arrays_ = [] |
| 566 | for array in arrays: |
| 567 | xp_array, _ = get_namespace(array) |
| 568 | if getattr(array, "dtype", None) == xp_array.float64: |
| 569 | arrays_.append(xp_array.astype(array, xp_array.float32)) |
| 570 | else: |
| 571 | arrays_.append(array) |
| 572 | |
| 573 | converted_arrays = [] |
| 574 | for array, is_sparse, is_none in zip(arrays_, sparse_mask, none_mask): |
| 575 | if is_none: |
| 576 | converted_arrays.append(None) |
searching dependent graphs…