True if a CUDA device with Tensor Core support (compute >= 7) exists.
()
| 366 | |
@functools.cache
def has_tensorcore() -> bool:
    """Report whether a Tensor Core capable CUDA device (compute >= 7) exists.

    The probe is best-effort: any exception raised while importing the nvcc
    helper, checking for CUDA, or reading the device's compute version is
    treated as "no Tensor Core support". The result is memoized by
    ``functools.cache``, so the probe body runs at most once per process.

    Returns:
        A truthy value only when ``has_cuda()`` is truthy and
        ``nvcc.have_tensorcore`` accepts the device's compute version;
        otherwise a falsy value.
    """
    try:
        # Imported lazily so a failing import is absorbed by the handler
        # below and simply reported as "unsupported".
        from tvm.support import nvcc  # pylint: disable=import-outside-toplevel

        cuda_present = has_cuda()
        if not cuda_present:
            # Hand back has_cuda()'s own falsy result without querying the device.
            return cuda_present

        compute_version = tvm.cuda().compute_version
        return bool(nvcc.have_tensorcore(compute_version))
    except Exception:  # pylint: disable=broad-except
        return False
| 376 | |
| 377 | |
| 378 | @functools.cache |