()
# Problem size: A, B and C in test_matmul are all N x N. Read through getenv with a default of 4096
# (presumably overridable via an N environment variable — confirm against the getenv helper).
N = getenv("N", 4096)
| 202 | |
def _relative_rmse(out, ref):
  """Return sqrt(mean((out - ref)**2)) / sqrt(mean(ref**2)) as a Python float.

  NaN or inf in ``out`` propagate into the result, so callers should check it with ``np.isfinite``.
  """
  return float(np.sqrt(np.mean((out - ref)**2)) / np.sqrt(np.mean(ref**2)))

def test_matmul():
  """Benchmark the hand-built kernel on an N x N fp16 matmul and optionally verify the result.

  Builds the instruction stream with ``build_kernel(N, arch)`` and wraps it in a custom PROGRAM kernel
  that writes into ``c``. Runs it ``getenv("CNT", 5)`` times and prints the best-case TFLOPS.
  When ``getenv("VERIFY", 1)`` is truthy, compares ``c`` against a float32 numpy reference ``a @ b``.

  Raises:
    ValueError: if N is not a multiple of BLOCK_N and BLOCK_M (the floor-divided grid would be short).
    RuntimeError: if the relative RMSE versus the reference is non-finite or exceeds 0.05.
  """
  # The launch grid below floor-divides N by the block sizes, so any remainder gets no workgroup.
  # NOTE(review): assumes each workgroup covers exactly one BLOCK_N x BLOCK_M tile of C — confirm in build_kernel.
  if N % BLOCK_N or N % BLOCK_M:
    raise ValueError(f"N={N} must be a multiple of BLOCK_N={BLOCK_N} and BLOCK_M={BLOCK_M}")

  dev = Device[Device.DEFAULT]
  # Fall back to 'gfx1200' when the renderer exposes no `arch` attribute.
  arch = getattr(dev.renderer, 'arch', 'gfx1200')
  print(f"Device arch: {arch}")
  insts = build_kernel(N, arch)

  # Fixed seed for reproducible inputs: uniform [0, 1) float32 samples cast down to fp16.
  rng = np.random.default_rng(42)
  a = Tensor(rng.random((N, N), dtype=np.float32).astype(np.float16))
  b = Tensor(rng.random((N, N), dtype=np.float32).astype(np.float16))
  c = Tensor.empty(N, N, dtype=dtypes.half)
  Tensor.realize(a, b, c)

  grid, local = (N//BLOCK_N, N//BLOCK_M, 1), (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")

  dname = Device.DEFAULT
  def asm_kernel(A, B, C):
    gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
    lidxs = [UOp.special(THREADS, "lidx0")]
    # Local buffer of at least LDS_SIZE bytes, padded up to 65536 // LIMIT_OCC (getenv default 2).
    # NOTE(review): the padding presumably limits how many workgroups share a CU's LDS — confirm intent.
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
    # Estimates: 2*N^3 ops (multiply + add per MAC); mem = three N x N fp16 matrices at 2 bytes each.
    sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
                    arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

  c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
  linear = c.schedule_linear()

  ets = []
  # NOTE(review): DEBUG=2 is presumably what makes GlobalCounters.time_sum_s record kernel time — confirm.
  with Context(DEBUG=2):
    for _ in range(getenv("CNT", 5)):
      start = GlobalCounters.time_sum_s
      run_linear(linear)
      ets.append(GlobalCounters.time_sum_s - start)
  # Previously CNT<=0 crashed in min() on an empty list, and a zero timing crashed with ZeroDivisionError.
  if ets and min(ets) > 0: print(f"REAL TFLOPS {N*N*N*2 / min(ets) * 1e-12:.2f}")
  else: print(f"REAL TFLOPS unavailable (timings: {ets})")

  if getenv("VERIFY", 1):
    GlobalCounters.reset()
    c_np = c.float().numpy()
    a_np, b_np = a.float().numpy(), b.float().numpy()
    ref = a_np @ b_np
    err = _relative_rmse(c_np, ref)
    print(f"relative RMSE {err:.6f}")
    # isfinite rejects NaN (which compares False against the bound) as well as inf.
    if not np.isfinite(err) or err > 0.05: raise RuntimeError(f"matmul is wrong! RMSE={err}")
| 246 | |
# Entry point: run the benchmark/verification only when this file is executed directly.
if __name__ == "__main__": test_matmul()
no test coverage detected
searching dependent graphs…