(v, k=0)
| 626 | |
| 627 | @derived_from(np) |
| 628 | def diag(v, k=0): |
| 629 | if not isinstance(v, np.ndarray) and not isinstance(v, Array): |
| 630 | raise TypeError(f"v must be a dask array or numpy array, got {type(v)}") |
| 631 | |
| 632 | name = "diag-" + tokenize(v, k) |
| 633 | |
| 634 | meta = meta_from_array(v, 2 if v.ndim == 1 else 1) |
| 635 | |
| 636 | if isinstance(v, np.ndarray) or ( |
| 637 | hasattr(v, "__array_function__") and not isinstance(v, Array) |
| 638 | ): |
| 639 | if v.ndim == 1: |
| 640 | m = abs(k) |
| 641 | chunks = ((v.shape[0] + m,), (v.shape[0] + m,)) |
| 642 | key = (name, 0, 0) |
| 643 | dsk = {key: Task(key, np.diag, v, k)} |
| 644 | elif v.ndim == 2: |
| 645 | kdiag_row_start = max(0, -k) |
| 646 | kdiag_row_stop = min(v.shape[0], v.shape[1] - k) |
| 647 | len_kdiag = kdiag_row_stop - kdiag_row_start |
| 648 | chunks = ((0,),) if len_kdiag <= 0 else ((len_kdiag,),) |
| 649 | key = (name, 0) |
| 650 | dsk = {key: Task(key, np.diag, v, k)} |
| 651 | else: |
| 652 | raise ValueError("Array must be 1d or 2d only") |
| 653 | return Array(dsk, name, chunks, meta=meta) |
| 654 | |
| 655 | if v.ndim != 1: |
| 656 | if v.ndim != 2: |
| 657 | raise ValueError("Array must be 1d or 2d only") |
| 658 | if k == 0 and v.chunks[0] == v.chunks[1]: |
| 659 | tasks = [ |
| 660 | Task((name, i), np.diag, TaskRef(row[i])) |
| 661 | for i, row in enumerate(v.__dask_keys__()) |
| 662 | ] |
| 663 | dsk = {t.key: t for t in tasks} |
| 664 | graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v]) |
| 665 | return Array(graph, name, (v.chunks[0],), meta=meta) |
| 666 | else: |
| 667 | return diagonal(v, k) |
| 668 | |
| 669 | if k == 0: |
| 670 | chunks_1d = v.chunks[0] |
| 671 | blocks = v.__dask_keys__() |
| 672 | dsk = {} |
| 673 | for i, m in enumerate(chunks_1d): |
| 674 | for j, n in enumerate(chunks_1d): |
| 675 | key = (name, i, j) |
| 676 | if i == j: |
| 677 | dsk[key] = Task(key, np.diag, TaskRef(blocks[i])) |
| 678 | else: |
| 679 | dsk[key] = Task(key, np.zeros_like, meta, shape=(m, n)) |
| 680 | |
| 681 | graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v]) |
| 682 | return Array(graph, name, (chunks_1d, chunks_1d), meta=meta) |
| 683 | |
| 684 | elif k > 0: |
| 685 | return pad(diag(v), [[0, k], [k, 0]], mode="constant") |
no test coverage detected
searching dependent graphs…