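Selects elements from `self` based on the boolean `mask`, which is broadcast to the shape of `self`; the result is a 1-D tensor in row-major order. With `size=None` (default), the output length equals the number of `True` values; because that length depends on the mask's contents, this mode is not jittable. With `size=N`, the output length is always `N`: the result is padded with `fill_value` when fewer than `N` elements are selected and truncated when more are, so this mode is jittable.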
(self, mask, size:int|None=None, fill_value:ConstType=0)
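The two output-length modes can be pinned down with a small pure-Python reference. This is a sketch, not tinygrad code: `masked_select_ref` is a hypothetical helper that works on flat lists which are assumed to be already flattened in row-major order and broadcast to the same length, which is what `masked_select` does with `flatten()` and `_broadcast_to(self.shape)` before selecting.

```python
from typing import Optional, Sequence


def masked_select_ref(values: Sequence, mask: Sequence[bool],
                      size: Optional[int] = None, fill_value=0) -> list:
  # Hypothetical reference helper, not part of tinygrad. Both inputs are
  # assumed flat and already broadcast to the same length (row-major order).
  if len(values) != len(mask):
    raise ValueError("values and mask must have the same length")
  selected = [v for v, m in zip(values, mask) if m]
  if size is None:
    # Output length depends on the data: one element per True in the mask.
    return selected
  # Fixed output length: truncate to `size`, then pad with `fill_value`.
  return selected[:size] + [fill_value] * max(0, size - len(selected))
```

On the flattened docstring example (`t` holds 0 through 8 and the mask is `True` at positions 0, 2, 4 and 8), this sketch returns `[0, 2, 4, 8]` with `size=None`, `[0, 2, 4, 8, -1, -1]` with `size=6, fill_value=-1`, and `[0, 2]` with `size=2`.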
````python
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
  """
  Selects elements from `self` based on the boolean `mask`.

  With `size=None` (default), output length equals the number of `True` values (not jittable).
  With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
  mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
  print(t.numpy())
  print(mask.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.masked_select(mask).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.masked_select(mask, size=6, fill_value=-1).numpy())
  ```
  """
  if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
  x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
  mask_cumsum = mask.cumsum()
  if size is None:
    counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device, buffer=False)
    return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
  counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
  return (Tensor.arange(size, device=self.device) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
````
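The index computation in the `masked_select` source is the non-obvious part: instead of a data-dependent filter, it turns the mask into gather indices with a cumsum, a scatter-add histogram and a second cumsum, so every intermediate has a length known before the data is read. The pure-Python sketch below mirrors those steps on flat lists. `gather_indices` and `masked_select_sketch` are hypothetical names, and dropping running counts that fall outside the histogram is an assumption about how the listing's `scatter` treats out-of-range indices.

```python
from itertools import accumulate
from typing import List, Optional


def gather_indices(mask: List[bool], size: int) -> List[int]:
  # Step 1: inclusive running count of True values (mask.cumsum() in the source).
  mask_cumsum = list(accumulate(int(m) for m in mask))
  # Step 2: histogram of the running counts into `size` bins, analogous to
  # zeros(size).scatter(0, mask_cumsum, 1, reduce='add'). Counts >= size are
  # dropped here (assumed out-of-range behaviour).
  counts = [0] * size
  for c in mask_cumsum:
    if c < size:
      counts[c] += 1
  # Step 3: prefix sum of the histogram. Entry k is the number of positions
  # whose running count is <= k, i.e. the position of the (k+1)-th True.
  return list(accumulate(counts))


def masked_select_sketch(values: list, mask: List[bool],
                         size: Optional[int] = None, fill_value=0) -> list:
  # Number of selected elements (mask_cumsum[-1] / mask.sum() in the source).
  n_true = sum(mask)
  out_len = n_true if size is None else size
  idx = gather_indices(mask, out_len)
  # Slots past the last selected element point beyond the input, so they take
  # fill_value instead (the source uses arange(size) < mask.sum() with where).
  return [values[i] if k < n_true else fill_value for k, i in enumerate(idx)]
```

For the docstring's mask, the running counts are `[1, 1, 2, 2, 3, 3, 3, 3, 4]`; with `size=6` the histogram is `[0, 2, 2, 4, 1, 0]` and its prefix sum is `[0, 2, 4, 8, 9, 9]`. The first four entries are the positions of the `True` values, and the last two point one past the end of the input, which is why those slots have to be overwritten with `fill_value`.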
| 1071 | def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor: |
| 1072 | """ |