MCPcopy
hub / github.com/tinygrad/tinygrad / test_matmul

Function test_matmul

extra/gemm/amd_asm_matmul.py:442–497  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

THREADS = 128  # workgroup size: test_matmul launches the asm kernel with local=(THREADS, 1, 1)
441
def _report_tile_errors(c_np:np.ndarray, ref_np:np.ndarray, tile:int=128) -> None:
  """Print a per-tile breakdown of how a matmul result `c_np` differs from `ref_np`.

  The output is walked in `tile` x `tile` blocks. When the matrix size is not a multiple
  of `tile`, the last row/column of blocks is smaller, and it is still inspected.

  For each block this prints two things:
    - the rows of `c_np` that are all ~0 (|x| < 1e-10);
    - the mean squared error over the remaining rows.
  When that error exceeds 1e-6, it also prints the first 8 values of up to three
  non-zero rows from both arrays.

  Args:
    c_np: result under test, 2-D.
    ref_np: reference result, same shape as `c_np`.
    tile: edge length of the inspected blocks (default 128).
  """
  rows, cols = c_np.shape
  for bi, r0 in enumerate(range(0, rows, tile)):
    for bj, c0 in enumerate(range(0, cols, tile)):
      blk_c = c_np[r0:r0+tile, c0:c0+tile]
      blk_ref = ref_np[r0:r0+tile, c0:c0+tile]
      # rows that came back entirely ~0 (e.g. rows the kernel may never have written)
      zero_mask = np.all(np.abs(blk_c) < 1e-10, axis=1)
      zero_rows = np.flatnonzero(zero_mask).tolist()
      nz_rows = np.flatnonzero(~zero_mask).tolist()
      nz_mse = float(np.mean((blk_c[~zero_mask] - blk_ref[~zero_mask])**2)) if nz_rows else 0.0
      print(f"Block ({bi},{bj}): zero_rows={zero_rows}, nz_rows_mse={nz_mse:.2e}")
      # show first few non-zero row comparisons
      if nz_rows and nz_mse > 1e-6:
        for r in nz_rows[:3]:
          print(f" row {r} asm[0:8]: {blk_c[r,:8]}")
          print(f" row {r} ref[0:8]: {blk_ref[r,:8]}")

def test_matmul():
  """Run the hand-written assembly matmul (from build_kernel(N)), benchmark it, and verify it.

  Environment variables read here:
    CNT       -- number of timed runs (default 5); the best run is reported as TFLOPS.
    VERIFY    -- when truthy (default 1), compare against tinygrad's `a @ b`.
    LIMIT_OCC -- feeds the LDS allocation size (65536 // LIMIT_OCC, floored at LDS_SIZE).

  Raises:
    RuntimeError: if VERIFY is set and the mean squared error vs. the reference is NaN
      or greater than 1e-6.
  """
  dev = Device[Device.DEFAULT]
  print(f"Device arch: {dev.renderer.target.arch}")

  insts = build_kernel(N)

  rng = np.random.default_rng(42)
  # float32 inputs drawn uniformly from [-0.5, 0.5)
  a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  b = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  c = Tensor.empty(N, N)
  Tensor.realize(a, b, c)

  # launch shape: (N//BLOCK_N) x (N//BLOCK_M) workgroups of THREADS threads each
  grid, local = (N // BLOCK_N, N // BLOCK_M, 1), (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")

  dname:str = Device.DEFAULT
  def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
    # Wrap the prebuilt instruction list `insts` in a PROGRAM UOp so custom_kernel runs it as-is.
    gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
    lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
    # LDS buffer of at least LDS_SIZE bytes; LIMIT_OCC can force a larger one
    # (presumably to cap occupancy -- NOTE(review): confirm intent)
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
    sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
                    estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
  c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
  linear = c.schedule_linear()

  # each run's kernel time is the delta in GlobalCounters.time_sum_s
  ets = []
  with Context(DEBUG=2):
    for _ in range(getenv("CNT", 5)):
      start = GlobalCounters.time_sum_s
      run_linear(linear)
      ets.append(GlobalCounters.time_sum_s - start)
  # with CNT=0 nothing was timed; min() of an empty list would raise before verification
  if ets: print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")

  if getenv("VERIFY", 1):
    GlobalCounters.reset()
    with Context(DEBUG=2): tc = (a @ b).realize()
    with Context(DEBUG=0): err = (c - tc).square().mean().item()
    print(f"mean squared error {err}")
    if np.isnan(err) or err > 1e-06:
      _report_tile_errors(c.numpy(), tc.numpy())
      raise RuntimeError("matmul is wrong!")
498
499if __name__ == "__main__":

Callers 2

amd_asm_matmul.pyFile · 0.70

Calls 15

TensorClass · 0.90
ContextClass · 0.90
getenvFunction · 0.90
run_linearFunction · 0.90
realizeMethod · 0.80
schedule_linearMethod · 0.80
appendMethod · 0.80
itemMethod · 0.80
meanMethod · 0.80
squareMethod · 0.80
allMethod · 0.80
absMethod · 0.80

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…