MCPcopy
hub / github.com/tinygrad/tinygrad / render_kernel

Method render_kernel

tinygrad/renderer/cstyle.py:537–576  ·  view source on GitHub ↗
(self, function_name, kernel, bufs, uops, prefix=None)

Source from the content-addressed store, hash-verified

535 f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"
536
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  """Assemble the prelude lines this kernel needs, then delegate rendering to the parent class.

  The prelude is derived from `uops` and contains, as needed:
    - INFINITY / NAN macros when a CONST uop carries a non-finite argument,
    - a `size_t` typedef and `__ockl_get_*` declarations when a SPECIAL uop is present,
    - `__ocml_*` declarations for EXP2/LOG2/SQRT/SIN/TRUNC on half/float/double,
    - typedefs / defines for bfloat16, half and fp8 plus the `f32_to_fp8` helper,
    - vector helper prefixes for every used dtype with `count > 1`,
    - one `__<name>` macro or wrapper function per entry yielded by `wmma_args(uops)`.

  NOTE: the incoming `prefix` argument is not consulted; it is replaced by the freshly built list
  before being forwarded to `super().render_kernel`.

  Returns whatever `super().render_kernel(function_name, kernel, bufs, uops, prefix)` returns.
  """
  prefix, ockl = [], []
  # base dtype -> builtin-name fragment table; the WMMA loop below derives per-shape variants from it
  type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16", dtypes.fp8e4m3: "_fp8_fp8", dtypes.fp8e5m2: "_bf8_bf8" }
  used_dtypes = uops_to_dtypes(uops)

  # non-finite constants are rendered through these macros
  if any(u.op is Ops.CONST and not math.isfinite(u.arg) for u in uops):
    prefix += ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))"]

  # SPECIAL uops need the __ockl_get_* declarations, which in turn need size_t
  if any(u.op is Ops.SPECIAL for u in uops):
    prefix.append("typedef long unsigned int size_t;")
    ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]

  # op -> (ocml function stem, function attribute)
  ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")}
  ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.bitsize}", dt.name, dt.name, ocml_ops[op][1])
          for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)]

  # scalar type aliases for the low-precision dtypes in use
  if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes):
    prefix.append(f"typedef {'__bf16' if self.is_cdna4(self.target.arch) else 'unsigned short'} hip_bfloat16;")
  if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16")
  if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes):
    prefix += ["typedef unsigned char hip_bf8;", "typedef unsigned char hip_fp8;"]
    # the conversion helper is only emitted when a float -> fp8 CAST actually occurs
    if any(u.op is Ops.CAST and u.dtype in dtypes.fp8s and u.src[0].dtype == dtypes.float for u in uops):
      prefix.append("""static inline __attribute__((device)) unsigned char f32_to_fp8(float v, int is_bf8) {
v = (((*(unsigned*)&v)&0x7F800000)!=0x7F800000)?__builtin_amdgcn_fmed3f(v,is_bf8?57344.0f:448.0f,is_bf8?-57344.0f:-448.0f) : v;
return (unsigned char)(is_bf8?__builtin_amdgcn_cvt_pk_bf8_f32(v,v,0,false):__builtin_amdgcn_cvt_pk_fp8_f32(v,v,0,false));\n}""")

  # extern declarations for every ockl / ocml function collected above, then the vector helpers
  prefix += [f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml]
  prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

  for name, (N, M, K), dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
    if self.is_cdna(self.target.arch):
      # Start from the base table for every WMMA: the shape-specific suffix overrides below must not
      # carry over into a later entry that has a different (N, M, K).
      mfma_map = dict(type_map)
      if (N, M, K) == (16, 16, 16): mfma_map[dtypes.bfloat16] = 'bf16_1k'
      elif (N, M, K) == (16, 16, 32): mfma_map.update({dtypes.bfloat16: "_bf16", dtypes.half: "_f16"})
      elif (N, M, K) == (16, 16, 128): mfma_map.update({dtypes.fp8e4m3: "_f8f6f4", dtypes.fp8e5m2: "_f8f6f4"})
      prefix.append(f"#define __{name} __builtin_amdgcn_mfma_{'scale_' if K == 128 else ''}f32_{N}x{M}x{K}{mfma_map[dtype_in]}")
    # #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
    elif self.tensor_cores == tc.amd_rdna4:
      prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12")
    elif dtype_out == dtypes.float:
      prefix.append(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if dtype_in == dtypes.half else 'bf16'}_w32")
    # remaining case: emit a wrapper function that widens the half8 accumulator to half16 and back
    else: prefix.append(f"static inline __attribute__((device)) half8 __{name}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)
577
def supported_dtypes(self):
  """Return the parent's supported dtypes with the fp8 formats this target does not accept removed.

  Members of `dtypes.fp8_fnuz` are always dropped; members of `dtypes.fp8_ocp` are kept only when
  `self.target.arch` is "gfx950".
  """
  def _keep(d):
    if d in dtypes.fp8_fnuz: return False
    return d not in dtypes.fp8_ocp or self.target.arch == "gfx950"
  return set(filter(_keep, super().supported_dtypes()))

Callers 5

_render_bodyMethod · 0.45
render_kernelMethod · 0.45
render_kernelMethod · 0.45
render_kernelMethod · 0.45
render_kernelMethod · 0.45

Calls 9

is_cdna4Method · 0.95
render_vector_prefixMethod · 0.95
is_cdnaMethod · 0.95
dedupFunction · 0.90
uops_to_dtypesFunction · 0.85
wmma_argsFunction · 0.85
isfiniteMethod · 0.80
appendMethod · 0.80
scalarMethod · 0.80

Tested by

no test coverage detected