(self, N=getenv("GEMM_N", 64))
| 41 | self.assertEqual(out.item(), N) |
| 42 | |
def test_gemm(self, N=getenv("GEMM_N", 64)):
  """Check that an all-ones N x N matrix times the N x N identity is all ones.

  N defaults to getenv("GEMM_N", 64). Every element of the product is
  compared exactly against 1.0, then the product's dtype is checked against
  dtypes.float.
  """
  ones_mat = Tensor.ones(N, N).contiguous()
  identity = Tensor.eye(N).contiguous()
  out = ones_mat @ identity
  lst = out.tolist()
  # Index with range(N) rather than iterating lst, so a result with too few
  # rows or columns raises instead of passing silently.
  for y in range(N):
    row = lst[y]
    for x in range(N):
      self.assertEqual(row[x], 1.0, msg=f"mismatch at ({y},{x})")
  self.assertEqual(out.dtype, dtypes.float)
| 51 | |
| 52 | def test_gemv(self, N=getenv("GEMV_N", 64), out_dtype=dtypes.float): |
| 53 | a = Tensor.ones(1,N).contiguous() |
no test coverage detected