(a, offset=0, axis1=0, axis2=1)
| 689 | |
| 690 | @derived_from(np) |
| 691 | def diagonal(a, offset=0, axis1=0, axis2=1): |
| 692 | name = "diagonal-" + tokenize(a, offset, axis1, axis2) |
| 693 | |
| 694 | if a.ndim < 2: |
| 695 | # NumPy uses `diag` as we do here. |
| 696 | raise ValueError("diag requires an array of at least two dimensions") |
| 697 | |
def _axis_fmt(axis, name, ndim):
    """Normalize ``axis`` to a non-negative index into an ``ndim``-dimensional array.

    Parameters
    ----------
    axis : int
        Axis index; negative values count back from the last dimension.
    name : str
        Parameter name (e.g. ``"axis1"``) used as the error-message prefix.
    ndim : int
        Number of dimensions of the array the axis refers to.

    Returns
    -------
    int
        ``axis`` mapped into ``range(ndim)``.

    Raises
    ------
    AxisError
        If ``axis`` lies outside ``[-ndim, ndim)``.
    """
    normalized = axis + ndim if axis < 0 else axis
    # Check both ends of the range: previously only negative axes were
    # validated, so a too-large positive axis slipped through and failed
    # later with a bare IndexError instead of this AxisError.
    if not 0 <= normalized < ndim:
        msg = "{}: axis {} is out of bounds for array of dimension {}"
        raise AxisError(msg.format(name, axis, ndim))
    return normalized
| 706 | |
def pop_axes(chunks, axis1, axis2):
    """Return ``chunks`` as a tuple with the entries at ``axis2`` and ``axis1`` removed.

    ``axis2`` is removed first and then ``axis1``. When ``axis1 < axis2``,
    both indices therefore refer to positions in the original sequence.
    """
    remaining = list(chunks)
    # Removal order matters: dropping the later index first keeps the
    # earlier index pointing at the same element.
    for position in (axis2, axis1):
        remaining.pop(position)
    return tuple(remaining)
| 712 | |
| 713 | axis1 = _axis_fmt(axis1, "axis1", a.ndim) |
| 714 | axis2 = _axis_fmt(axis2, "axis2", a.ndim) |
| 715 | |
| 716 | if axis1 == axis2: |
| 717 | raise ValueError("axis1 and axis2 cannot be the same") |
| 718 | |
| 719 | a = asarray(a) |
| 720 | k = offset |
| 721 | if axis1 > axis2: |
| 722 | axis1, axis2 = axis2, axis1 |
| 723 | k = -offset |
| 724 | |
| 725 | free_axes = set(range(a.ndim)) - {axis1, axis2} |
| 726 | free_indices = list(product(*(range(a.numblocks[i]) for i in free_axes))) |
| 727 | ndims_free = len(free_axes) |
| 728 | |
| 729 | # equation of diagonal: i = j - k |
| 730 | kdiag_row_start = max(0, -k) |
| 731 | kdiag_col_start = max(0, k) |
| 732 | kdiag_row_stop = min(a.shape[axis1], a.shape[axis2] - k) |
| 733 | len_kdiag = kdiag_row_stop - kdiag_row_start |
| 734 | |
| 735 | if len_kdiag <= 0: |
| 736 | xp = np |
| 737 | |
| 738 | if is_cupy_type(a._meta): |
| 739 | import cupy |
| 740 | |
| 741 | xp = cupy |
| 742 | |
| 743 | out_chunks = pop_axes(a.chunks, axis1, axis2) + ((0,),) |
| 744 | dsk = {} |
| 745 | for free_idx in free_indices: |
| 746 | shape = tuple( |
| 747 | out_chunks[axis][free_idx[axis]] for axis in range(ndims_free) |
| 748 | ) |
no test coverage detected
searching dependent graphs…