MCP · copy
hub / github.com/tinygrad/tinygrad / test_matmul

Function test_matmul

extra/gemm/rdna4_asm_matmul.py:203–245  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

# Side length of the square matrices A, B and C; overridable through getenv("N"), default 4096.
N = getenv("N", 4096)

def test_matmul():
  """Benchmark and (optionally) verify the hand-written assembly matmul on the default device.

  Builds the kernel instructions with `build_kernel(N, arch)` and launches them through
  `Tensor.custom_kernel` on random N x N fp16 inputs. It then runs the scheduled program
  CNT times (at least once) and prints TFLOPS computed from the fastest run. When VERIFY
  is truthy (default 1), the fp16 result is compared against a float32 numpy matmul of
  the same inputs.

  Settings read via getenv: CNT (timed runs, default 5), VERIFY (default 1),
  LIMIT_OCC (default 2).

  Raises:
    RuntimeError: if the relative RMSE against the reference is NaN/inf or exceeds 0.05.
  """
  dev = Device[Device.DEFAULT]
  # fall back to gfx1200 when the renderer does not expose an `arch` attribute
  arch = getattr(dev.renderer, 'arch', 'gfx1200')
  print(f"Device arch: {arch}")
  insts = build_kernel(N, arch)

  rng = np.random.default_rng(42)
  a = Tensor(rng.random((N, N), dtype=np.float32).astype(np.float16))
  b = Tensor(rng.random((N, N), dtype=np.float32).astype(np.float16))
  c = Tensor.empty(N, N, dtype=dtypes.half)
  Tensor.realize(a, b, c)

  # NOTE(review): N//BLOCK_N and N//BLOCK_M truncate, so this assumes N is a multiple of both tile sizes — confirm
  grid, local = (N//BLOCK_N, N//BLOCK_M, 1), (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")

  dname = Device.DEFAULT
  def asm_kernel(A, B, C):
    """Wrap the prebuilt instruction list `insts` in a PROGRAM UOp that binds buffers A, B, C."""
    gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
    lidxs = [UOp.special(THREADS, "lidx0")]
    # the local buffer is at least LDS_SIZE bytes; presumably the 65536//LIMIT_OCC floor is there
    # to limit how many workgroups can be resident at once (occupancy) — TODO confirm
    lds_bytes = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 2))
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=lds_bytes, addrspace=AddrSpace.LOCAL), (), 'lds')
    # cost estimates: 2*N^3 ops; memory = three N*N half-precision matrices at 2 bytes each
    sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
                    arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

  # index 2 corresponds to the third argument, `c` — presumably the buffer the kernel writes
  c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
  linear = c.schedule_linear()

  # CNT<=0 would leave `ets` empty and make min() below raise; always time at least one run
  cnt = max(1, getenv("CNT", 5))
  ets = []
  # presumably DEBUG=2 is what makes GlobalCounters.time_sum_s accumulate kernel time — TODO confirm
  with Context(DEBUG=2):
    for _ in range(cnt):
      start = GlobalCounters.time_sum_s
      run_linear(linear)
      ets.append(GlobalCounters.time_sum_s - start)
  # throughput is reported from the fastest timed run
  print(f"REAL TFLOPS {N*N*N*2 / min(ets) * 1e-12:.2f}")

  if getenv("VERIFY", 1):
    GlobalCounters.reset()
    c_np = c.float().numpy()
    a_np, b_np = a.float().numpy(), b.float().numpy()
    ref = a_np @ b_np
    # relative RMSE of the fp16 kernel output against the float32 reference
    err = np.sqrt(np.mean((c_np - ref)**2)) / np.sqrt(np.mean(ref**2))
    print(f"relative RMSE {err:.6f}")
    # isfinite rejects NaN (what the old `err != err` test caught) as well as inf
    if not np.isfinite(err) or err > 0.05: raise RuntimeError(f"matmul is wrong! RMSE={err}")

# Script entry point: run the benchmark/verification when executed directly.
if __name__ == "__main__":
  test_matmul()

Callers 1

Calls 15

Tensor (class) · 0.90
Context (class) · 0.90
getenv (function) · 0.90
run_linear (function) · 0.90
realize (method) · 0.80
schedule_linear (method) · 0.80
append (method) · 0.80
float (method) · 0.80
sqrt (method) · 0.80
mean (method) · 0.80
build_kernel (function) · 0.70
empty (method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…