A reduction to count the number of elements. It takes an additional argument, `coerce_np_ndarray`, which determines whether the result is forced to be a numpy.ndarray or is allowed to be another array type (created via `np.full_like`).
(x, coerce_np_ndarray: bool, **kwargs)
| 277 | |
| 278 | |
def _numel(x, coerce_np_ndarray: bool, **kwargs):
    """
    A reduction to count the number of elements.

    Parameters
    ----------
    x : array-like
        Object exposing ``.shape``. Its contents are never read; when
        ``coerce_np_ndarray`` is false it is also passed to
        ``np.full_like`` as the prototype for the result.
    coerce_np_ndarray : bool
        If true, the array result is built with ``np.full`` /
        ``np.broadcast_to`` and is therefore a ``numpy.ndarray``. If false,
        it is built with ``np.full_like(x, ...)`` so that other array types
        are allowed.
    **kwargs
        ``keepdims`` (default ``False``), ``axis`` (default ``None``; an
        int or a tuple/list of ints, negative values count from the end)
        and ``dtype`` (default ``np.float64``) are honoured; any other
        keyword is ignored.

    Returns
    -------
    The element count along ``axis``. With ``axis=None`` and
    ``keepdims`` false this is the scalar returned by ``np.prod``;
    otherwise it is an array of the reduced shape filled with the count.

    Raises
    ------
    IndexError
        If an entry of ``axis`` is out of range for ``x.shape``.
    """
    shape = x.shape
    ndim = len(shape)
    # Coerce to a plain bool: identity checks against True/False silently
    # take the wrong branch for np.bool_ flags.
    keepdims = bool(kwargs.get("keepdims", False))
    axis = kwargs.get("axis")
    dtype = kwargs.get("dtype", np.float64)

    if axis is None:
        total = np.prod(shape, dtype=dtype)
        if not keepdims:
            return total

        all_ones = (1,) * ndim
        if coerce_np_ndarray:
            return np.full(shape=all_ones, fill_value=total, dtype=dtype)
        return np.full_like(x, total, shape=all_ones, dtype=dtype)

    if not isinstance(axis, (tuple, list)):
        axis = [axis]

    # Index with the caller's values first so that an out-of-range axis
    # still raises IndexError before any normalisation happens.
    count = math.prod(shape[dim] for dim in axis)

    # Map negative axes onto their non-negative position; otherwise they
    # never match the positions enumerated below and the reduced axes would
    # be left in the output shape.
    reduced = {dim % ndim for dim in axis}

    if keepdims:
        new_shape = tuple(
            1 if pos in reduced else size for pos, size in enumerate(shape)
        )
    else:
        new_shape = tuple(
            size for pos, size in enumerate(shape) if pos not in reduced
        )

    if coerce_np_ndarray:
        # Read-only broadcast view: every element is the same count.
        return np.broadcast_to(np.array(count, dtype=dtype), new_shape)
    return np.full_like(x, count, shape=new_shape, dtype=dtype)
| 317 | |
| 318 | |
| 319 | @nannumel_lookup.register((object, np.ndarray)) |
no test coverage detected
searching dependent graphs…