| 360 | return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) |
| 361 | |
def any(x: Array,
        /,
        *,
        axis: int | tuple[int, ...] | None = None,
        keepdims: bool = False,
        **kwargs: object) -> Array:
    """Test whether any element of ``x`` is truthy along the given axis/axes.

    Thin wrapper around ``torch.any`` that adds handling for an empty-tuple
    axis, a tuple of axes, and ``axis=None`` combined with ``keepdims``, and
    that always converts the result to ``torch.bool``.

    Parameters
    ----------
    x
        Input tensor.
    axis
        Axis or axes to reduce over. ``None`` reduces over the whole tensor;
        an empty tuple performs no reduction at all.
    keepdims
        Whether the reduced axes are kept in the result.
    **kwargs
        Extra keyword arguments handed to the underlying ``torch.any`` call
        (via ``_reduce_multiple_axes`` when ``axis`` is a tuple). They are
        not used when ``axis == ()`` because no reduction is performed.

    Returns
    -------
    Array
        The reduction result, converted to ``torch.bool``.
    """
    # Nothing to reduce over: the result is just the element-wise truth value.
    if axis == ():
        return x.to(torch.bool)

    # torch.any doesn't support multiple axes
    # (https://github.com/pytorch/pytorch/issues/56586).
    if isinstance(axis, tuple):
        res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
        return res.to(torch.bool)

    if axis is None:
        # torch doesn't support keepdims with axis=None
        # (https://github.com/pytorch/pytorch/issues/71209)
        res = torch.any(x, **kwargs)
        res = _axis_none_keepdims(res, x.ndim, keepdims)
        return res.to(torch.bool)

    # Single integer axis. Forward **kwargs here as well so this path agrees
    # with the branches above instead of silently discarding them.
    # torch.any doesn't return bool for uint8, hence the explicit conversion.
    return torch.any(x, axis, keepdims=keepdims, **kwargs).to(torch.bool)
| 385 | |
| 386 | def all(x: Array, |
| 387 | /, |