MCPcopy
hub / github.com/tinygrad/tinygrad / render_kernel

Method render_kernel

tinygrad/renderer/cstyle.py:440–468  ·  view source on GitHub ↗
(self, function_name, kernel, bufs, uops, prefix=None)

Source from the content-addressed store, hash-verified

438 return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
439
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  """Render a kernel with the defines, includes and helper functions it needs prepended.

  The prefix list is assembled here in this order:
    1. INFINITY / NAN macro definitions and the `tg_bitcast` template,
    2. `cuda_fp8.h` / `cuda_fp16.h` / `cuda_bf16.h` includes, each only when a dtype
       with the matching scalar type appears in `uops_to_dtypes(uops)`,
    3. `render_vector_prefix` output for the selected half/bfloat16/fp8 vector dtypes,
    4. one `__device__` wrapper around an `mma.sync` asm statement per entry of `wmma_args(uops)`.

  The result is forwarded to the parent class's `render_kernel`, whose return value is returned.
  The `prefix` argument is accepted for signature compatibility but its incoming value is not read.
  """
  # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
  prefix = [
    "#define INFINITY (__int_as_float(0x7f800000))",
    "#define NAN (__int_as_float(0x7fffffff))",
    "template <class T, class F> __device__ __forceinline__ T tg_bitcast(F v) { union U { F f; T t; }; U u; u.f = v; return u.t; }",
  ]
  used_dtypes = uops_to_dtypes(uops)

  # each header is included only if some used dtype has the matching scalar type
  header_rules = [
    (lambda sc: sc in dtypes.fp8s, "#include <cuda_fp8.h>"),
    (lambda sc: sc == dtypes.half, "#include <cuda_fp16.h>"),
    (lambda sc: sc == dtypes.bfloat16, "#include <cuda_bf16.h>"),
  ]
  for scalar_matches, include_line in header_rules:
    if any(scalar_matches(dt.scalar()) for dt in used_dtypes): prefix.append(include_line)

  def wants_vector_prefix(dt):
    # half/bfloat16 vectors of 4 or 8 elements, or fp8 vectors of 2, 4, 8 or 16 elements
    if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}: return True
    return dt.count in (2,4,8,16) and dt.scalar() in dtypes.fp8s
  prefix.extend(self.render_vector_prefix(dt) for dt in used_dtypes if wants_vector_prefix(dt))

  # type suffixes used in the mma instruction name
  dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2" }
  dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }

  for wmma in wmma_args(uops):
    name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ = wmma
    upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
    # operand order is a, b, c: both inputs use dtype_in, the accumulator uses dtype_out
    sized_operands = list(zip([dtype_in, dtype_in, dtype_out], upcast_sizes))
    wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in sized_operands]
    n_operands = [size*dtype.itemsize//4 for dtype, size in sized_operands] # 4 => CUDA reg size in bytes
    operands = [f"%{i}" for i in range(sum(n_operands))]

    a_type, b_type, c_type = wmma_dtypes[0], wmma_dtypes[1], wmma_dtypes[2]
    n_a, n_b, n_c = n_operands[0], n_operands[1], n_operands[2]
    in_tag, out_tag = dt_map_in[dtype_in], dt_map_out[dtype_out]

    # asm placeholders are numbered in constraint order: c first, then a, then b
    c_regs = ", ".join(operands[:n_c])
    a_regs = ", ".join(operands[n_c:n_c+n_a])
    b_regs = ", ".join(operands[-n_b:])
    c_constraints = ", ".join([f'"+r"(c_pk[{i}])' for i in range(n_c)])
    a_constraints = ", ".join([f'"r"(a_pk[{i}])' for i in range(n_a)])
    b_constraints = ", ".join([f'"r"(b_pk[{i}])' for i in range(n_b)])

    # mma operands => {c}, {a}, {b}, {c}
    prefix.append(f"""__device__ {c_type} __{name}({a_type} a, {b_type} b, {c_type} c){{
  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{out_tag}.{in_tag}.{in_tag}.{out_tag}"
      "{{{c_regs}}}, {{{a_regs}}},"
      "{{{b_regs}}}, {{{c_regs}}};"
    : {c_constraints}
    : {a_constraints}, {b_constraints});
  return c;\n}}""")

  return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
469
470 def supported_dtypes(self):
471 ver = int(self.target.arch[3:])

Callers

nothing calls this directly

Calls 9

render_vector_prefixMethod · 0.95
prodFunction · 0.90
uops_to_dtypesFunction · 0.85
wmma_argsFunction · 0.85
scalarMethod · 0.80
appendMethod · 0.80
render_dtypeMethod · 0.45
vecMethod · 0.45
render_kernelMethod · 0.45

Tested by

no test coverage detected