(a, axis=None)
| 1989 | |
@derived_from(np)
def squeeze(a, axis=None):
    # Remove length-one dimensions from the dask array ``a``.
    #
    # Documentation is kept in comments on purpose: ``derived_from(np)``
    # builds ``squeeze.__doc__`` from NumPy's docstring, and a local docstring
    # would be folded into that generated text.
    #
    # Parameters
    #   a    : the dask array to squeeze.
    #   axis : None (squeeze every length-one axis), a single axis, or a tuple
    #          of axes; values are normalised by ``validate_axis``.
    # Returns
    #   ``a`` indexed with 0 on each squeezed axis, or, when the array has only
    #   size-1 axes and all of them are squeezed, a 0-d array produced by
    #   ``map_blocks(np.squeeze, ...)`` (matching NumPy's 0-d result).
    # Raises
    #   ValueError if a selected axis does not have length one. Invalid axis
    #   values are rejected by ``validate_axis`` before any shape lookup.
    if axis is None:
        axis = tuple(i for i, d in enumerate(a.shape) if d == 1)
    elif not isinstance(axis, tuple):
        axis = (axis,)

    # Validate/normalise the axes *before* indexing ``a.shape``. Checking the
    # sizes first let an out-of-range axis escape as a bare IndexError from
    # the tuple lookup instead of validate_axis's axis error.
    axis = validate_axis(axis, a.ndim)

    if any(a.shape[i] != 1 for i in axis):
        raise ValueError("cannot squeeze axis with size other than one")

    # Integer 0 drops a dimension; slice(None) keeps it.
    sl = tuple(0 if i in axis else slice(None) for i in range(a.ndim))

    # Return 0d Dask Array if all axes are squeezed,
    # to be consistent with NumPy. Ref: https://github.com/dask/dask/issues/9183#issuecomment-1155626619
    if all(s == 0 for s in sl) and all(s == 1 for s in a.shape):
        return a.map_blocks(
            np.squeeze, meta=a._meta, drop_axis=tuple(range(a.ndim))
        )

    return a[sl]
| 2014 | |
| 2015 | |
| 2016 | @derived_from(np) |
no test coverage detected
searching dependent graphs…