(data: torch.Tensor, device="CUDA")
def f(tg_out, tg_data):
    """Store a fixed weighted sum of three channels of ``tg_data`` into ``tg_out``.

    ``tg_data`` is indexed as ``[:, :, c]`` for ``c`` in 0..2, and the three
    slices are combined with weights 0.2989, 0.5870 and 0.1140.
    NOTE(review): these weights look like the usual RGB -> grayscale luma
    coefficients -- confirm the channel order with the caller.

    Returns whatever ``tg_out.assign(...).realize()`` returns.
    """
    ch0 = tg_data[:, :, 0]
    ch1 = tg_data[:, :, 1]
    ch2 = tg_data[:, :, 2]
    # Keep the left-to-right summation order of the weighted terms.
    weighted = ch0 * 0.2989 + ch1 * 0.5870 + ch2 * 0.1140
    return tg_out.assign(weighted).realize()
| 12 | |
def custom_kernel(data: torch.Tensor, device="CUDA") -> torch.Tensor:
    """Run ``f`` on a torch tensor by wrapping its memory as tinygrad tensors.

    The input and a freshly allocated output are handed to tinygrad through
    ``Tensor.from_blob`` on their raw data pointers, ``f`` is executed under
    ``Context(BEAM=2)``, and the torch-owned output tensor is returned.

    Args:
        data: float32 torch tensor; ``f`` indexes it as ``[:, :, c]``, so it
            is presumably 3-D with at least three entries on the last axis
            (TODO confirm with callers).
        device: tinygrad device name used for the wrapped buffers and for the
            final synchronize. NOTE(review): it is not derived from
            ``data.device``; the caller must pass a matching name.

    Returns:
        A torch tensor of shape ``(data.shape[0], data.shape[1])`` with the
        same dtype and device as ``data``.

    Raises:
        AssertionError: if ``data`` is not float32.
    """
    assert data.dtype == torch.float32

    # from_blob is given only a pointer and a shape (no strides), so the
    # memory behind the pointer must be laid out densely in shape order.
    # contiguous() returns the same tensor when that already holds and makes
    # a dense copy otherwise (e.g. for permuted or strided views).
    data = data.contiguous()
    tg_data = Tensor.from_blob(data.data_ptr(), data.shape, dtype=_from_torch_dtype(data.dtype), device=device)

    out = torch.empty((data.shape[0], data.shape[1]), dtype=data.dtype, device=data.device)
    tg_out = Tensor.from_blob(out.data_ptr(), out.shape, dtype=_from_torch_dtype(out.dtype), device=device)

    # Need to sync torch to make sure the data (including any contiguous()
    # copy made above) is valid before tinygrad reads it.
    if data.device.type == "mps":
        torch.mps.synchronize()
    else:
        torch.cuda.synchronize()

    with Context(BEAM=2):
        f(tg_out, tg_data)

    # Wait for the tinygrad computation to finish so ``out`` holds the result.
    Device[device].synchronize()

    return out
| 30 | |
| 31 | if __name__ == "__main__": |
| 32 | for i in range(3): |
no test coverage detected
searching dependent graphs…