| 84 | |
| 85 | |
def get_extensions():
    """Return the list of extension modules to build for ``sam2``.

    When ``BUILD_CUDA`` is falsy, nothing is built and ``[]`` is returned.
    Otherwise a single ``CUDAExtension`` named ``sam2._C`` is created from
    ``sam2/csrc/connected_components.cu``, compiled with FP16 enabled and
    CUDA's built-in half-precision operators/conversions disabled.

    If importing ``torch.utils.cpp_extension`` or constructing the
    ``CUDAExtension`` raises, the outcome depends on ``BUILD_ALLOW_ERRORS``:

    * truthy: the error is printed using ``CUDA_ERROR_MSG`` (assumed to be a
      ``str.format`` template with one placeholder), and ``[]`` is returned
      so the build can continue without the extension;
    * falsy: the original exception propagates with its traceback intact.

    Returns:
        list: ``[]`` or a one-element list holding the ``CUDAExtension``.
    """
    if not BUILD_CUDA:
        return []

    # Plain literals that cannot fail; kept outside the ``try`` so it only
    # guards the operations that can actually raise.
    srcs = ["sam2/csrc/connected_components.cu"]
    compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    }

    try:
        # Imported here so this function only touches torch when a CUDA
        # build was actually requested.
        from torch.utils.cpp_extension import CUDAExtension

        ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
    except Exception as e:
        # Intentionally broad: a failure in either the import or the
        # CUDAExtension(...) call is handled the same way.
        if not BUILD_ALLOW_ERRORS:
            # Bare ``raise`` re-raises the active exception unchanged.
            raise
        print(CUDA_ERROR_MSG.format(e))
        return []

    return ext_modules
| 112 | |
| 113 | |
| 114 | try: |