hub / github.com/tinygrad/tinygrad / custom_kernel

Function custom_kernel

examples/torch_cuda_kernel.py:13–29
(data: torch.Tensor, device="CUDA")

Source from the content-addressed store, hash-verified

11  def f(tg_out, tg_data): return tg_out.assign(tg_data[:, :, 0] * 0.2989 + tg_data[:, :, 1] * 0.5870 + tg_data[:, :, 2] * 0.1140).realize()
12
13  def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
14    assert data.dtype == torch.float32
15    tg_data = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype), device=device)
16
17    out = torch.empty((data.shape[0], data.shape[1]), dtype=data.dtype, device=data.device)
18    tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype), device=device)
19
20    # Need to sync torch to make sure the data is valid.
21    if data.device.type == "mps": torch.mps.synchronize()
22    else: torch.cuda.synchronize()
23
24    with Context(BEAM=2): f(tg_out, tg_data)
25
26    # Wait for computation to finish and the data is valid.
27    Device[device].synchronize()
28
29    return out
30
31  if __name__ == "__main__":
32    for i in range(3):
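
The arithmetic inside f is a weighted sum of the first three channels, which can be reproduced on the CPU to sanity-check what custom_kernel returns. The sketch below is not part of the tinygrad example: grayscale_reference and LUMA_WEIGHTS are hypothetical names, NumPy stands in for torch so the snippet runs without a GPU, and the shape assertion is an added assumption (the original only asserts float32).

```python
import numpy as np

# Weights copied from f() in the listing above, applied to channels 0, 1, 2.
LUMA_WEIGHTS = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)

def grayscale_reference(data: np.ndarray) -> np.ndarray:
    """Hypothetical CPU reference: (H, W, C>=3) float32 -> (H, W) float32."""
    assert data.dtype == np.float32          # same dtype check as custom_kernel
    assert data.ndim == 3 and data.shape[2] >= 3  # assumption: channels-last layout
    return (data[:, :, 0] * LUMA_WEIGHTS[0]
            + data[:, :, 1] * LUMA_WEIGHTS[1]
            + data[:, :, 2] * LUMA_WEIGHTS[2])

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.random((4, 5, 3), dtype=np.float32)
    print(grayscale_reference(img).shape)  # (4, 5)
```

On a machine with a supported GPU, one plausible check is to copy the result of custom_kernel back to the host and compare it with this reference using np.allclose. A suitable tolerance depends on how the generated kernel orders the multiply-adds, so it would need to be chosen empirically.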

Callers 1

Calls 6

_from_torch_dtype · Function · 0.90
Context · Class · 0.90
from_blob · Method · 0.80
f · Function · 0.70
empty · Method · 0.45
synchronize · Method · 0.45

Tested by

no test coverage detected
