
Method where

tinygrad/tensor.py:1275–1297 (github.com/tinygrad/tinygrad)

Returns a tensor whose elements are selected from `x` or `y` according to the condition tensor `self`: `output_i = x_i if self_i else y_i`.
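As a rough mental model, the selection rule can be written in plain Python over nested lists. This is a hypothetical stdlib-only sketch (`where_sketch` is not a tinygrad function): scalar branches are reused at every position, which loosely mirrors how constant `x` or `y` values are broadcast, but list branches must already match the condition's shape. Python truthiness at each leaf stands in for the `cast(dtypes.bool)` step in the source, so nonzero numeric conditions count as true.

```python
from typing import Any


def where_sketch(cond: Any, x: Any, y: Any) -> Any:
    """Hypothetical model of `cond.where(x, y)` on nested Python lists.

    `x` and `y` may be scalars (reused everywhere) or lists shaped like `cond`.
    Full broadcasting between mismatched list shapes is NOT modeled.
    """
    if isinstance(cond, list):
        # Expand scalar branches to this level's length; keep list branches as-is.
        xs = x if isinstance(x, list) else [x] * len(cond)
        ys = y if isinstance(y, list) else [y] * len(cond)
        if len(xs) != len(cond) or len(ys) != len(cond):
            raise ValueError("x and y must be scalars or match the shape of cond")
        return [where_sketch(c, xi, yi) for c, xi, yi in zip(cond, xs, ys)]
    # Leaf: output_i = x_i if self_i else y_i (truthiness approximates a cast to bool).
    return x if cond else y


cond = [[True, True, False], [True, False, False]]
print(where_sketch(cond, 1, 3))  # [[1, 1, 3], [1, 3, 3]]
```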

`where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor`
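Because `x` and `y` may each be a `Tensor`, a Python constant (`ConstType`), or a symbolic int (`sint`), tensor and scalar branches can be mixed in one call. The docstring's last example, `(cond > 0).where(cond, -float("inf"))`, uses this to keep positive entries and replace the rest with negative infinity, a common masking idiom. The sketch below reproduces that pattern in plain Python on a 2-D list; `mask_non_positive` is a hypothetical helper, not part of tinygrad.

```python
import math


def mask_non_positive(values: list[list[float]]) -> list[list[float]]:
    """Hypothetical pure-Python equivalent of `(t > 0).where(t, -float("inf"))`.

    The condition is `v > 0`; the `x` branch is the value itself (a tensor
    operand in tinygrad) and the `y` branch is the scalar -inf.
    """
    return [[v if v > 0 else -math.inf for v in row] for row in values]


print(mask_non_positive([[0.5, -1.2, 0.0], [-0.3, 2.0, 1.1]]))
# [[0.5, -inf, -inf], [-inf, 2.0, 1.1]]
```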


````python
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
  """
  Returns a tensor of elements selected from either `x` or `y`, depending on `self`.
  `output_i = x_i if self_i else y_i`.

  ```python exec="true" source="above" session="tensor" result="python"
  cond = Tensor([[True, True, False], [True, False, False]])
  print(cond.where(1, 3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  cond = Tensor.randn(2, 3)
  print(cond.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print((cond > 0).where(cond, -float("inf")).numpy())
  ```
  """
  # x is a Tensor: broadcast y (Tensor or constant) against it
  if isinstance(x, Tensor): x, y = x._broadcasted(y)
  # only y is a Tensor: broadcast x against y
  elif isinstance(y, Tensor): y, x = y._broadcasted(x)
  # both are constants: wrap x as a Tensor on the condition's device, then broadcast y against it
  else: x, y = Tensor(x, self.device)._broadcasted(y)
  # the output shape also accounts for the condition's own shape
  out_shape = _broadcast_shape(self.shape, x.shape)
  # cast the condition to bool, broadcast all three operands, and apply the ternary WHERE op
  return self.cast(dtypes.bool)._broadcast_to(out_shape)._apply_uop(UOp.where, x._broadcast_to(out_shape), y._broadcast_to(out_shape))
````

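The last two lines of the method compute one output shape from the condition and the already-broadcast branches, then broadcast all three operands to it. The stdlib-only sketch below spells out NumPy-style broadcasting for the shape arithmetic alone; `broadcast_shape` is a hypothetical stand-in for tinygrad's `_broadcast_shape`, whose handling of zero-sized or symbolic dimensions and of incompatible shapes may differ.

```python
from itertools import zip_longest


def broadcast_shape(*shapes: tuple[int, ...]) -> tuple[int, ...]:
    """NumPy-style broadcast of shapes (hypothetical stand-in for `_broadcast_shape`).

    Shapes are right-aligned; in each aligned column every size must be 1 or
    equal to the others, and the result takes the size that is not 1.
    """
    out = []
    # Walk dimensions from the right; missing leading dimensions count as size 1.
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        out.append(non_one.pop() if non_one else 1)
    return tuple(reversed(out))


# Mirror the two steps in `where`: branches first, then the condition.
xy_shape = broadcast_shape((), (3,))           # scalar x against y of shape (3,)
out_shape = broadcast_shape((2, 3), xy_shape)  # condition of shape (2, 3) against that
print(xy_shape, out_shape)  # (3,) (2, 3)
```

Because the condition's shape enters only in the second step, the condition can be larger than either branch, as in the docstring's first example where both branches are scalars.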
Callers (15; 11 listed)

- `test_ops` (method) · 0.95
- `gradient.py` (file) · 0.45
- `_getitem` (method) · 0.45
- `masked_select` (method) · 0.45
- `dropout` (method) · 0.45
- `ipad` (method) · 0.45
- `u16_to_f16` (function) · 0.45
- `get_lr` (method) · 0.45
- `fuzz_matmul.py` (file) · 0.45
- `max_matmul.py` (file) · 0.45
- `hip_matmul.py` (file) · 0.45

Calls (6)

- `cast` (method) · 0.95
- `_broadcast_shape` (function) · 0.90
- `Tensor` (class) · 0.85
- `_apply_uop` (method) · 0.80
- `_broadcast_to` (method) · 0.80
- `_broadcasted` (method) · 0.45

Tested by (15; 6 listed)

- `test_ops` (method) · 0.76
- `test_gated_load` (method) · 0.36
- `test_where_removal` (method) · 0.36
- `test_where_combine` (method) · 0.36
- `test_where_cast` (method) · 0.36
- `test_invalid_times_0` (method) · 0.36