MCP · copy · Index your code
hub / github.com/dask/dask / diag

Function diag

dask/array/creation.py:628–687  ·  view source on GitHub ↗
(v, k=0)

Source from the content-addressed store, hash-verified

626
627@derived_from(np)
628def diag(v, k=0):
629 if not isinstance(v, np.ndarray) and not isinstance(v, Array):
630 raise TypeError(f"v must be a dask array or numpy array, got {type(v)}")
631
632 name = "diag-" + tokenize(v, k)
633
634 meta = meta_from_array(v, 2 if v.ndim == 1 else 1)
635
636 if isinstance(v, np.ndarray) or (
637 hasattr(v, "__array_function__") and not isinstance(v, Array)
638 ):
639 if v.ndim == 1:
640 m = abs(k)
641 chunks = ((v.shape[0] + m,), (v.shape[0] + m,))
642 key = (name, 0, 0)
643 dsk = {key: Task(key, np.diag, v, k)}
644 elif v.ndim == 2:
645 kdiag_row_start = max(0, -k)
646 kdiag_row_stop = min(v.shape[0], v.shape[1] - k)
647 len_kdiag = kdiag_row_stop - kdiag_row_start
648 chunks = ((0,),) if len_kdiag <= 0 else ((len_kdiag,),)
649 key = (name, 0)
650 dsk = {key: Task(key, np.diag, v, k)}
651 else:
652 raise ValueError("Array must be 1d or 2d only")
653 return Array(dsk, name, chunks, meta=meta)
654
655 if v.ndim != 1:
656 if v.ndim != 2:
657 raise ValueError("Array must be 1d or 2d only")
658 if k == 0 and v.chunks[0] == v.chunks[1]:
659 tasks = [
660 Task((name, i), np.diag, TaskRef(row[i]))
661 for i, row in enumerate(v.__dask_keys__())
662 ]
663 dsk = {t.key: t for t in tasks}
664 graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
665 return Array(graph, name, (v.chunks[0],), meta=meta)
666 else:
667 return diagonal(v, k)
668
669 if k == 0:
670 chunks_1d = v.chunks[0]
671 blocks = v.__dask_keys__()
672 dsk = {}
673 for i, m in enumerate(chunks_1d):
674 for j, n in enumerate(chunks_1d):
675 key = (name, i, j)
676 if i == j:
677 dsk[key] = Task(key, np.diag, TaskRef(blocks[i]))
678 else:
679 dsk[key] = Task(key, np.zeros_like, meta, shape=(m, n))
680
681 graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
682 return Array(graph, name, (chunks_1d, chunks_1d), meta=meta)
683
684 elif k > 0:
685 return pad(diag(v), [[0, k], [k, 0]], mode="constant")

Callers 1

corrcoef · Function · 0.90

Calls 11

meta_from_array · Function · 0.90
Task · Class · 0.90
Array · Class · 0.90
TaskRef · Class · 0.90
max · Function · 0.85
min · Function · 0.85
diagonal · Function · 0.85
pad · Function · 0.85
from_collections · Method · 0.80
tokenize · Function · 0.50
__dask_keys__ · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…