(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None)
| 132 | extra_matcher = extra_pm |
| 133 | |
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
  """Assemble the complete source text of one kernel.

  The result is the kernel signature line, an optional image-sampler
  declaration, the already-rendered body lines in `kernel` joined by
  newlines, and a closing brace.

  Args:
    function_name: name emitted for the kernel function.
    kernel: rendered body lines, joined with newlines.
    bufs: (name, (dtype, mutable)) pairs; each becomes one parameter.
    uops: scanned for SPECIAL ops whose arg starts with "l"; the product of
      the `vmax` of their first source is passed to `kernel_typedef` as
      `launch_bounds`.
    prefix: optional lines placed, newline-joined, before the kernel text.

  Returns:
    The kernel source, preceded by the prefix lines when `prefix` is given.
  """
  # the sampler declaration is emitted only when at least one buffer is an image
  uses_image = any(isinstance(dt, ImageDType) for _,(dt,_) in bufs)
  sampler_decl = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if uses_image else ""

  # one "<type> <name>" entry per buffer; a dtype that is neither image/pointer nor int gets the type None
  params = []
  for name,(dt,mutable) in bufs:
    if isinstance(dt, (ImageDType, PtrDType)): ctype = self.render_dtype(dt, mutable)+self.buffer_suffix
    elif dt == dtypes.int: ctype = self.arg_int_prefix
    else: ctype = None
    params.append(f'{ctype} {name}')

  # launch bounds: product of the upper bounds of every local-dimension SPECIAL
  launch_bounds = prod([u.src[0].vmax for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"])

  signature = f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(" + ', '.join(params + self.extra_args) + ") {\n"
  prg = signature + sampler_decl + '\n'.join(kernel) + "\n}"

  if prefix is None: return prg
  return "\n".join(prefix)+f"\n{prg}"
| 146 | |
def render_cast(self, dt:DType, val: str) -> str:
  """Return a cast expression: `val` in parentheses, preceded by the rendered `dt` in parentheses."""
  target = self.render_dtype(dt)
  return f"({target})({val})"
| 148 | def render_dtype(self, dt:DType, mutable=True) -> str: |
no test coverage detected