Recursive np.concatenate. Input should be a nested list of numpy arrays arranged in the order they should appear in the array itself. Each array should have the same number of dimensions as the desired output and the nesting of the lists. >>> x = np.array([[1, 2]]) >>> concatenate3([[x, x, x], [x, x, x]])
(arrays)
| 5628 | |
| 5629 | |
| 5630 | def concatenate3(arrays): |
| 5631 | """Recursive np.concatenate |
| 5632 | |
| 5633 | Input should be a nested list of numpy arrays arranged in the order they |
| 5634 | should appear in the array itself. Each array should have the same number |
| 5635 | of dimensions as the desired output and the nesting of the lists. |
| 5636 | |
| 5637 | >>> x = np.array([[1, 2]]) |
| 5638 | >>> concatenate3([[x, x, x], [x, x, x]]) |
| 5639 | array([[1, 2, 1, 2, 1, 2], |
| 5640 | [1, 2, 1, 2, 1, 2]]) |
| 5641 | |
| 5642 | >>> concatenate3([[x, x], [x, x], [x, x]]) |
| 5643 | array([[1, 2, 1, 2], |
| 5644 | [1, 2, 1, 2], |
| 5645 | [1, 2, 1, 2]]) |
| 5646 | """ |
| 5647 | # We need this as __array_function__ may not exist on older NumPy versions. |
| 5648 | # And to reduce verbosity. |
| 5649 | NDARRAY_ARRAY_FUNCTION = getattr(np.ndarray, "__array_function__", None) |
| 5650 | |
| 5651 | arrays = concrete(arrays) |
| 5652 | if not arrays or all(el is None for el in flatten(arrays)): |
| 5653 | return np.empty(0) |
| 5654 | |
| 5655 | advanced = max( |
| 5656 | core.flatten(arrays, container=(list, tuple)), |
| 5657 | key=lambda x: getattr(x, "__array_priority__", 0), |
| 5658 | ) |
| 5659 | |
| 5660 | if not all( |
| 5661 | NDARRAY_ARRAY_FUNCTION |
| 5662 | is getattr(type(arr), "__array_function__", NDARRAY_ARRAY_FUNCTION) |
| 5663 | for arr in core.flatten(arrays, container=(list, tuple)) |
| 5664 | ): |
| 5665 | try: |
| 5666 | x = unpack_singleton(arrays) |
| 5667 | return _concatenate2(arrays, axes=tuple(range(x.ndim))) |
| 5668 | except TypeError: |
| 5669 | pass |
| 5670 | |
| 5671 | if concatenate_lookup.dispatch(type(advanced)) is not np.concatenate: |
| 5672 | x = unpack_singleton(arrays) |
| 5673 | return _concatenate2(arrays, axes=list(range(x.ndim))) |
| 5674 | |
| 5675 | ndim = ndimlist(arrays) |
| 5676 | if not ndim: |
| 5677 | return arrays |
| 5678 | chunks = chunks_from_arrays(arrays) |
| 5679 | shape = tuple(map(sum, chunks)) |
| 5680 | |
| 5681 | def dtype(x): |
| 5682 | try: |
| 5683 | return x.dtype |
| 5684 | except AttributeError: |
| 5685 | return type(x) |
| 5686 | |
| 5687 | result = np.empty(shape=shape, dtype=dtype(deepfirst(arrays))) |
searching dependent graphs…