(self, function_name, kernel, bufs, uops, prefix=None)
| 438 | return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}" |
| 439 | |
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  """Render a CUDA kernel, building the CUDA-specific preamble and delegating to the parent renderer.

  The preamble assembled here contains, in order:
    * `INFINITY` / `NAN` macros and a union-based `tg_bitcast<T, F>` device template,
    * `#include <cuda_fp8.h>` / `<cuda_fp16.h>` / `<cuda_bf16.h>` when the uops use an fp8 / half /
      bfloat16 scalar type,
    * vector type definitions (via `self.render_vector_prefix`) for 4- or 8-wide half/bfloat16 vectors
      and 2-, 4-, 8- or 16-wide fp8 vectors,
    * one `__device__` helper `__<name>` per WMMA op reported by `wmma_args(uops)`, each wrapping a
      single `mma.sync.aligned` PTX instruction,
    * finally, any extra lines the caller passed in `prefix`.

  Args:
    function_name, kernel, bufs, uops: forwarded unchanged to the parent class's `render_kernel`.
    prefix: optional iterable of extra preamble lines. They are appended after the CUDA-generated
      preamble so they may refer to anything defined there. With the default `None` the rendered
      output is exactly the CUDA-generated preamble.

  Returns:
    The result of the parent class's `render_kernel` called with the assembled preamble.
  """
  # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
  # Build into a separate local so the caller's `prefix` argument is not clobbered.
  preamble = ["#define INFINITY (__int_as_float(0x7f800000))", "#define NAN (__int_as_float(0x7fffffff))",
    "template <class T, class F> __device__ __forceinline__ T tg_bitcast(F v) { union U { F f; T t; }; U u; u.f = v; return u.t; }"]
  used_dtypes = uops_to_dtypes(uops)
  if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): preamble.append("#include <cuda_fp8.h>")
  if any(dt.scalar() == dtypes.half for dt in used_dtypes): preamble.append("#include <cuda_fp16.h>")
  if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): preamble.append("#include <cuda_bf16.h>")
  # Only these vector widths/types need an explicit struct + make_* constructor emitted.
  preamble += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
    or (dt.count in (2,4,8,16) and dt.scalar() in dtypes.fp8s)]

  # PTX type suffixes for mma.sync: input (a/b) types and accumulator (c/d) types.
  dt_map_in = {dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2"}
  dt_map_out = {dtypes.float: "f32", dtypes.half: "f16"}
  for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
    # Element count of each operand (a, b, c) = product of its upcast axis sizes.
    upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
    wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
    # Number of 32-bit registers each operand occupies (4 => CUDA reg size in bytes).
    n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
    operands = [f"%{i}" for i in range(sum(n_operands))]

    # Inline-asm operands are numbered in constraint order: the "+r" (read-write) c registers come first,
    # then the "r" inputs for a, then b. The mma operand list is {d}, {a}, {b}, {c}, and d reuses c's registers.
    # mma operands => {c}, {a}, {b}, {c}
    preamble.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
      "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
      "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
    : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
    : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
  return c;\n}}""")

  # Caller-supplied lines go last so they may reference any type/helper defined above.
  if prefix is not None: preamble.extend(prefix)
  return super().render_kernel(function_name, kernel, bufs, uops, prefix=preamble)
| 469 | |
| 470 | def supported_dtypes(self): |
| 471 | ver = int(self.target.arch[3:]) |
nothing calls this directly
no test coverage detected