MCPcopy
hub / github.com/tinygrad/tinygrad / render_kernel

Method render_kernel

tinygrad/renderer/cstyle.py:381–390  ·  view source on GitHub ↗
(self, function_name, kernel, bufs, uops, prefix=None)

Source from the content-addressed store, hash-verified

379 ]) + base_rewrite
380
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  """Render a Metal kernel, prepending the Metal headers and one helper per distinct WMMA signature.

  The prefix handed to the parent renderer is assembled as:
    1. the Metal headers (`#include <metal_stdlib>` and `using namespace metal;`),
    2. any lines the caller supplied via `prefix` (the previous version silently discarded them),
    3. one `__{name}` helper per deduplicated (name, dtype_in, dtype_out) triple taken from
       `wmma_args(uops)`.

  Each generated helper does the following:
    - takes two-element vectors `a`, `b` (input dtype) and `c` (output dtype);
    - copies their elements into 8x8 simdgroup matrices;
    - calls `simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c)`;
    - returns the accumulator's two thread elements as a two-element vector of the output dtype.

  Args:
    function_name, kernel, bufs, uops: forwarded unchanged to the parent `render_kernel`;
      `uops` is additionally scanned by `wmma_args`.
    prefix: optional iterable of extra prefix entries, placed right after the Metal headers.
      NOTE(review): assumes the parent treats each prefix entry as a source line — confirm.

  Returns:
    Whatever the parent `render_kernel` returns for the assembled prefix.
  """
  # Metal headers go first; caller-supplied entries follow instead of being dropped.
  prefix = ["#include <metal_stdlib>", "using namespace metal;"] + list(prefix or [])
  # wmma_args yields 8-field tuples. Only the name (0), input dtype (2) and output dtype (3)
  # influence the helper text, so dedup on that triple to emit each helper once.
  deduped_wmma_args = dedup([(name, dtype_in, dtype_out) for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops)])
  for name, dtype_in, dtype_out in deduped_wmma_args:
    # Same evaluation order as before: output vector type first, then input vector type.
    dstr_out = self.render_dtype(dtype_out.vec(2))
    dstr_in = self.render_dtype(dtype_in.vec(2))
    prefix.append(f"""{dstr_out} __{name}({dstr_in} a, {dstr_in} b, {dstr_out} c){{
simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c;
mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)
391
392 def supported_dtypes(self):
393 return {d for d in super().supported_dtypes() if (d != dtypes.bfloat16 or ((arch:=self.target.arch).startswith("Apple") and int(arch[5:]) >= 6))

Callers

nothing calls this directly

Calls 6

dedup — Function · 0.90
wmma_args — Function · 0.85
append — Method · 0.80
render_dtype — Method · 0.45
vec — Method · 0.45
render_kernel — Method · 0.45

Tested by

no test coverage detected