(self, function_name, kernel, bufs, uops, prefix=None)
| 535 | f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}" |
| 536 | |
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  """Render a HIP/AMD kernel together with the device-side prelude it needs.

  The prelude is rebuilt from scratch by scanning `uops`, and only what the kernel uses is emitted:
    - INFINITY/NAN macros when a CONST op holds a non-finite value,
    - `__ockl_get_{local_id,group_id,local_size}` externs when SPECIAL ops appear,
    - `__ocml_*` externs for EXP2/LOG2/SQRT/SIN/TRUNC on half/float/double,
    - bfloat16 / half / fp8 typedefs and an `f32_to_fp8` helper for float->fp8 casts,
    - vector-type helpers for every vector dtype used,
    - one `__<name>` macro or wrapper per WMMA op.

  NOTE(review): the `prefix` argument is accepted for signature compatibility but is discarded;
  caller-supplied prefix lines are never emitted.

  Returns the result of the parent renderer's `render_kernel` called with the built prefix.
  """
  prefix, ockl = [], []
  # Suffixes used when naming the matrix-core builtins. Never mutated: each WMMA op below starts
  # from this base so a shape-specific override cannot leak into a later op of another shape.
  base_type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16", dtypes.fp8e4m3: "_fp8_fp8", dtypes.fp8e5m2: "_bf8_bf8" }
  used_dtypes = uops_to_dtypes(uops)
  if any(u.op is Ops.CONST and not math.isfinite(u.arg) for u in uops):
    prefix += ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))"]
  if any(u.op is Ops.SPECIAL for u in uops):
    prefix.append("typedef long unsigned int size_t;")
    # (function name, argument type, return type, function attribute)
    ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
  # op -> (ocml function stem, function attribute); one extern per distinct (op, scalar dtype) actually used
  ocml_ops = {Ops.EXP2: ("exp2", "pure"), Ops.LOG2: ("log2", "pure"), Ops.SQRT: ("sqrt", "const"), Ops.SIN: ("sin", ""), Ops.TRUNC: ("trunc", "")}
  ocml = [(f"__ocml_{ocml_ops[op][0]}_f{dt.bitsize}", dt.name, dt.name, ocml_ops[op][1])
    for op, dt in dedup((u.op, u.dtype.scalar()) for u in uops) if op in ocml_ops and dt in (dtypes.half, dtypes.float, dtypes.double)]
  if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes):
    prefix.append(f"typedef {'__bf16' if self.is_cdna4(self.target.arch) else 'unsigned short'} hip_bfloat16;")
  if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#define half _Float16")
  if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes):
    prefix += ["typedef unsigned char hip_bf8;", "typedef unsigned char hip_fp8;"]
    if any(u.op is Ops.CAST and u.dtype in dtypes.fp8s and u.src[0].dtype == dtypes.float for u in uops):
      # non-inf/nan inputs are clamped to the format's max magnitude (448 fp8 / 57344 bf8) before conversion
      prefix.append("""static inline __attribute__((device)) unsigned char f32_to_fp8(float v, int is_bf8) {
v = (((*(unsigned*)&v)&0x7F800000)!=0x7F800000)?__builtin_amdgcn_fmed3f(v,is_bf8?57344.0f:448.0f,is_bf8?-57344.0f:-448.0f) : v;
return (unsigned char)(is_bf8?__builtin_amdgcn_cvt_pk_bf8_f32(v,v,0,false):__builtin_amdgcn_cvt_pk_fp8_f32(v,v,0,false));\n}""")
  prefix += [f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml]
  prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

  for name, (N, M, K), dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
    if self.is_cdna(self.target.arch):
      # CDNA mfma builtin suffixes depend on the shape; use a fresh per-op copy of the base map.
      type_map = dict(base_type_map)
      if (N, M, K) == (16, 16, 16): type_map[dtypes.bfloat16] = 'bf16_1k'
      elif (N, M, K) == (16, 16, 32): type_map.update({dtypes.bfloat16: "_bf16", dtypes.half: "_f16"})
      elif (N, M, K) == (16, 16, 128): type_map.update({dtypes.fp8e4m3: "_f8f6f4", dtypes.fp8e5m2: "_f8f6f4"})
      prefix.append(f"#define __{name} __builtin_amdgcn_mfma_{'scale_' if K == 128 else ''}f32_{N}x{M}x{K}{type_map[dtype_in]}")
    elif self.tensor_cores == tc.amd_rdna4:
      # e.g. #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
      prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{base_type_map[dtype_out]}_16x16x16_{base_type_map[dtype_in]}_w32_gfx12")
    elif dtype_out == dtypes.float:
      prefix.append(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if dtype_in == dtypes.half else 'bf16'}_w32")
    # half accumulator: the builtin takes a 16-wide half vector; this wrapper places the 8 accumulator
    # values at the even positions and reads them back from the same positions
    else: prefix.append(f"static inline __attribute__((device)) half8 __{name}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)
| 577 | |
def supported_dtypes(self):
  """Return the parent renderer's supported dtypes, minus the fp8 formats this target cannot use.

  FNUZ fp8 dtypes are always excluded; OCP fp8 dtypes are kept only when the target arch is "gfx950".
  """
  candidates = super().supported_dtypes()
  return {dt for dt in candidates
          if not (dt in dtypes.fp8_ocp and self.target.arch != "gfx950") and dt not in dtypes.fp8_fnuz}
no test coverage detected