(self, function_name, kernel, bufs, uops, prefix=None)
| 379 | ]) + base_rewrite |
| 380 | |
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    """Render a kernel with the Metal preamble and WMMA helper functions prepended.

    The caller-supplied ``prefix`` is not read; it is replaced by the Metal
    include / namespace lines, followed by one ``__<name>`` helper per distinct
    (name, input dtype, output dtype) triple taken from ``wmma_args(uops)``.
    The resulting list is forwarded to the parent class's ``render_kernel``,
    whose return value is returned unchanged.
    """
    # NOTE: any prefix passed in by the caller is discarded at this point.
    prefix = ["#include <metal_stdlib>", "using namespace metal;"]

    # Each entry from wmma_args is unpacked as exactly eight fields; only the
    # name and the two dtypes are kept, then duplicates are removed.
    wmma_signatures = dedup([
        (wmma_name, in_dtype, out_dtype)
        for wmma_name, _, in_dtype, out_dtype, _, _, _, _ in wmma_args(uops)
    ])

    for wmma_name, in_dtype, out_dtype in wmma_signatures:
        # Rendered in the same order the original template evaluated them:
        # 2-wide output type, 2-wide input type, then the scalar element types.
        vec_out = self.render_dtype(out_dtype.vec(2))
        vec_in = self.render_dtype(in_dtype.vec(2))
        elem_in = self.render_dtype(in_dtype)
        elem_out = self.render_dtype(out_dtype)

        # Helper source: copies two elements per operand into simdgroup 8x8
        # matrices, calls simdgroup_multiply_accumulate, and returns the two
        # accumulator elements as a 2-wide vector.
        helper_src = f"""{vec_out} __{wmma_name}({vec_in} a, {vec_in} b, {vec_out} c){{
simdgroup_{elem_in}8x8 mat_a, mat_b; simdgroup_{elem_out}8x8 mat_c;
mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {vec_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}"""
        prefix.append(helper_src)

    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
| 391 | |
| 392 | def supported_dtypes(self): |
| 393 | return {d for d in super().supported_dtypes() if (d != dtypes.bfloat16 or ((arch:=self.target.arch).startswith("Apple") and int(arch[5:]) >= 6)) |
nothing calls this directly
no test coverage detected