hub / github.com/tinygrad/tinygrad / masked_select

Method masked_select

tinygrad/tensor.py:1042–1069

Selects elements from `self` based on the boolean `mask`. With `size=None` (default), the output length equals the number of `True` values in `mask`; because that length depends on the data, this mode is not jittable. With `size=N`, the output length is always `N`: the result is padded with `fill_value` if fewer than `N` elements are selected and truncated if more are selected, which makes this mode jittable.

(self, mask, size:int|None=None, fill_value:ConstType=0)
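The two modes described above can be modelled in plain Python. This is only a reference sketch of the documented behaviour, not tinygrad code: `masked_select_ref` is a hypothetical helper name, nested lists stand in for tensors, and the mask is assumed to already have the same shape as the data (no broadcasting).

```python
def masked_select_ref(values, mask, size=None, fill_value=0):
    """Hypothetical pure-Python model of the documented masked_select semantics."""
    # Flatten both inputs in row-major order.
    flat_values = [v for row in values for v in row]
    flat_mask = [m for row in mask for m in row]
    # Keep the values whose mask entry is True, preserving order.
    picked = [v for v, m in zip(flat_values, flat_mask) if m]
    if size is None:
        # Data-dependent length: one output element per True in the mask.
        return picked
    # Fixed length: pad with fill_value, then cut to exactly `size` elements.
    return (picked + [fill_value] * size)[:size]


t = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
mask = [[True, False, True], [False, True, False], [False, False, True]]

print(masked_select_ref(t, mask))                         # [0, 2, 4, 8]
print(masked_select_ref(t, mask, size=6, fill_value=-1))  # [0, 2, 4, 8, -1, -1]
print(masked_select_ref(t, mask, size=2))                 # [0, 2]
```

The inputs mirror the docstring example, so the first two results should correspond to what the tinygrad calls print there; the `size=2` call illustrates truncation.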

Source from the content-addressed store, hash-verified

    raise TypeError("Tensor does not support deleting items")

  def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
    """
    Selects elements from `self` based on the boolean `mask`.

    With `size=None` (default), output length equals the number of `True` values (not jittable).
    With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
    print(t.numpy())
    print(mask.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.masked_select(mask).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.masked_select(mask, size=6, fill_value=-1).numpy())
    ```
    """
    if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
    x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
    mask_cumsum = mask.cumsum()
    if size is None:
      counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device, buffer=False)
      return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
    counts = Tensor.zeros(size, dtype=dtypes.int32, device=self.device, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
    return (Tensor.arange(size, device=self.device) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)

  def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
    """
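The index computation in the listing (cumulative sum of the mask, scatter-add into `counts`, then a second cumulative sum) can be traced in plain Python. The idea is that `counts[k]` ends up holding the number of flat positions whose running count of `True` values equals `k`, so the running total of `counts` up to `k` is the number of positions that come before the `(k+1)`-th `True`, which is exactly that element's flat index. The sketch below is an illustration under assumptions, not the library's implementation: the helper names are hypothetical, lists stand in for tensors, and out-of-range scatter targets are skipped explicitly, which the listing appears to rely on `scatter` doing.

```python
from itertools import accumulate


def masked_select_indices(mask, size=None):
    """Hypothetical helper: flat gather indices via cumsum -> scatter-add -> cumsum."""
    # Running count of True values at each flat position (inclusive).
    mask_cumsum = list(accumulate(int(m) for m in mask))
    n_true = mask_cumsum[-1] if mask_cumsum else 0
    out_len = n_true if size is None else size
    # Scatter-add: counts[k] = number of positions whose running count equals k.
    counts = [0] * out_len
    for c in mask_cumsum:
        if c < out_len:  # assumption: out-of-range targets are dropped
            counts[c] += 1
    # Running total of counts gives the flat index of the (k+1)-th True value.
    return list(accumulate(counts)), n_true


def masked_select_sketch(values, mask, size=None, fill_value=0):
    """Hypothetical helper: gather with the indices above, padding past the last True."""
    idx, n_true = masked_select_indices(mask, size)
    # Slots at or beyond n_true hold no selected element and receive fill_value.
    return [values[i] if k < n_true else fill_value for k, i in enumerate(idx)]


flat_values = list(range(9))
flat_mask = [True, False, True, False, True, False, False, False, True]

print(masked_select_indices(flat_mask))          # ([0, 2, 4, 8], 4)
print(masked_select_indices(flat_mask, size=6))  # ([0, 2, 4, 8, 9, 9], 4)
print(masked_select_sketch(flat_values, flat_mask, size=6, fill_value=-1))  # [0, 2, 4, 8, -1, -1]
```

With `size=6` the last two indices equal the input length (9), which is past the end of the data. The sketch never reads those slots because they fall at or beyond the number of `True` values, which plays the same role as the `arange(size) < mask.sum()` condition in the listing.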

Callers 7

nonzero · Method · 0.80
NonZero · Function · 0.80
test_masked_select · Method · 0.80
f · Method · 0.80
test_masked_select · Method · 0.80

Calls 12

is_bool · Method · 0.80
flatten · Method · 0.80
_broadcast_to · Method · 0.80
cumsum · Method · 0.80
zeros · Method · 0.80
item · Method · 0.80
numel · Method · 0.80
scatter · Method · 0.80
arange · Method · 0.80
sum · Method · 0.80
cast · Method · 0.45
where · Method · 0.45

Tested by 5

test_masked_select · Method · 0.64
f · Method · 0.64
test_masked_select · Method · 0.64