()
# Size of the first local launch dimension: test_matmul launches with local = (THREADS, 1, 1).
THREADS = 128
| 441 | |
def _report_block_errors(c_np, tc_np, n, block=128):
  """Print per-tile diagnostics comparing the kernel output with the reference.

  The two (n, n) arrays are walked in `block` x `block` tiles; only the
  `n // block` full tiles along each axis are visited. For every tile this prints
  which rows of the kernel output are entirely (near) zero and the mean squared
  error over the remaining rows. When that error exceeds 1e-6, the first eight
  values of up to three non-zero rows are printed next to the reference values.

  Args:
    c_np: kernel output as a 2D numpy array.
    tc_np: reference result as a 2D numpy array with the same layout.
    n: matrix dimension.
    block: tile edge length used for the report.
  """
  for bi in range(n // block):
    for bj in range(n // block):
      rows, cols = slice(bi*block, (bi+1)*block), slice(bj*block, (bj+1)*block)
      blk_c, blk_ref = c_np[rows, cols], tc_np[rows, cols]
      # a row counts as "zero" when every entry has magnitude below 1e-10;
      # one vectorised mask replaces the per-row scan and list-membership test
      zero_mask = np.all(np.abs(blk_c) < 1e-10, axis=1)
      zero_rows = np.flatnonzero(zero_mask).tolist()
      nz_rows = np.flatnonzero(~zero_mask).tolist()
      nz_mse = float(np.mean((blk_c - blk_ref)[nz_rows, :]**2)) if nz_rows else 0
      print(f"Block ({bi},{bj}): zero_rows={zero_rows}, nz_rows_mse={nz_mse:.2e}")
      # show first few non-zero row comparisons
      if nz_rows and nz_mse > 1e-6:
        for r in nz_rows[:3]:
          print(f" row {r} asm[0:8]: {blk_c[r,:8]}")
          print(f" row {r} ref[0:8]: {blk_ref[r,:8]}")

def test_matmul():
  """Run the assembly matmul kernel, print its throughput, and optionally verify it.

  Steps:
    1. Build the instruction list with build_kernel(N) and create two N x N
       float32 inputs with values in [-0.5, 0.5) from a fixed seed (42).
    2. Wrap the instructions in a custom kernel and run the resulting schedule
       CNT times (env var, default 5); TFLOPS is reported from the fastest run.
    3. Unless VERIFY=0, compare against `a @ b` via the mean squared error.

  Raises:
    RuntimeError: if the mean squared error is NaN or greater than 1e-6. Per-tile
      diagnostics are printed before raising.
  """
  import math  # local import: only used for the NaN check during verification

  dev = Device[Device.DEFAULT]
  print(f"Device arch: {dev.renderer.target.arch}")

  insts = build_kernel(N)

  rng = np.random.default_rng(42)
  a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  b = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  c = Tensor.empty(N, N)
  Tensor.realize(a, b, c)

  grid, local = (N // BLOCK_N, N // BLOCK_M, 1), (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")

  dname:str = Device.DEFAULT
  def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
    """Build the PROGRAM UOp that wraps the prebuilt instruction list for buffers A, B, C."""
    gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
    lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
    # local byte buffer of at least LDS_SIZE; the LIMIT_OCC env var can request a larger one (65536 // LIMIT_OCC)
    # NOTE(review): presumably a larger local allocation is used to cap occupancy -- confirm
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
    # estimates: 2*N^3 ops, and three N x N buffers at 4 bytes per element
    sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
                    estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
  # custom_kernel returns one tensor per input; index 2 corresponds to the output c
  c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
  linear = c.schedule_linear()

  ets = []
  with Context(DEBUG=2):
    for _ in range(getenv("CNT", 5)):
      # each run's time is the growth of GlobalCounters.time_sum_s across run_linear
      start = GlobalCounters.time_sum_s
      run_linear(linear)
      ets.append(GlobalCounters.time_sum_s - start)
  print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")

  if getenv("VERIFY", 1):
    GlobalCounters.reset()
    with Context(DEBUG=2): tc = (a @ b).realize()
    with Context(DEBUG=0): err = (c - tc).square().mean().item()
    print(f"mean squared error {err}")
    if math.isnan(err) or err > 1e-06:
      _report_block_errors(c.numpy(), tc.numpy(), N)
      raise RuntimeError("matmul is wrong!")
| 498 | |
| 499 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…