(valid_cuda)
| 25 | |
| 26 | |
def skip_on_cuda(valid_cuda):
    """Skip the calling test unless the local CUDA version is in ``valid_cuda``.

    On a ``'cuda'`` accelerator, ``torch_info['cuda_version']`` (a dotted
    version string) is reduced to its major and minor components. These are
    encoded as ``major * 10 + minor`` (e.g. ``"11.8"`` -> ``118``), and the
    test is skipped when that value is not in ``valid_cuda``. On any other
    accelerator no version check is done; instead the current accelerator
    must be a supported one.

    Args:
        valid_cuda: Accepted encoded CUDA versions, e.g. ``[117, 118]``.
            Any container supporting ``in`` (list, tuple, set) works.

    Raises:
        AssertionError: On a non-CUDA accelerator that is not supported.

    Note:
        Skipping is done through ``pytest.skip``, which aborts the calling
        test by raising pytest's skip exception.
    """
    if get_accelerator().device_name() == 'cuda':
        # Only major.minor matter; any further components (e.g. a patch level) are ignored.
        cuda_major, cuda_minor = (int(part) for part in torch_info['cuda_version'].split('.')[:2])
        cuda_version = cuda_major * 10 + cuda_minor
        if cuda_version not in valid_cuda:
            pytest.skip(f"requires cuda versions {valid_cuda}")
    else:
        assert is_current_accelerator_supported()
| 37 | |
| 38 | |
| 39 | def bf16_required_version_check(accelerator_check=True): |
searching dependent graphs…