Check the runtime for the given dtype and case.
(dtype: str, case: MathCase, executable: Executable)
| 265 | |
| 266 | |
def check_runtime(dtype: str, case: MathCase, executable: Executable):
    """Run ``executable`` on CUDA device 0 and compare it to the NumPy reference.

    Inputs are produced by ``make_numpy_inputs(dtype, case)``. The reference is
    ``case.np_ref`` applied to those inputs cast to ``dtype``, with its result
    also cast to ``dtype``. The device run writes into a ``(VECTOR_N_INPUTS,)``
    tensor of ``dtype``. That tensor is passed as the final argument to
    ``executable``. The two results are compared with
    ``np.testing.assert_allclose`` using ``case.rtol`` / ``case.atol``.

    Raises:
        AssertionError: if the device output differs from the reference by more
            than the case's tolerances.
    """
    device = tvm.cuda(0)
    host_args = make_numpy_inputs(dtype, case)

    # Host-side reference: evaluate on dtype-cast copies and cast the result
    # back to ``dtype`` so both sides are compared in the same precision.
    ref_args = [a.astype(dtype) for a in host_args]
    reference = case.np_ref(*ref_args).astype(dtype)

    # Device-side run: upload inputs, allocate the destination tensor, and
    # pass it as the last argument to the compiled executable.
    device_args = [tvm.runtime.tensor(a, device=device) for a in host_args]
    result_buf = tvm.runtime.empty((VECTOR_N_INPUTS,), dtype, device)
    executable(*device_args, result_buf)
    # Wait for outstanding device work before reading the result back.
    device.sync()

    np.testing.assert_allclose(
        result_buf.numpy(), reference, rtol=case.rtol, atol=case.atol
    )
| 283 | |
| 284 | |
| 285 | @pytest.mark.parametrize("enable_fast_math", [False, True], ids=["default", "fast_math"]) |
no test coverage detected
searching dependent graphs…