(self)
| 506 | self.force = True |
| 507 | |
| 508 | def build_extensions(self) -> None: |
| 509 | compiler_name, compiler_version = self._check_abi() |
| 510 | |
| 511 | cuda_ext = False |
| 512 | extension_iter = iter(self.extensions) |
| 513 | extension = next(extension_iter, None) |
| 514 | while not cuda_ext and extension: |
| 515 | for source in extension.sources: |
| 516 | _, ext = os.path.splitext(source) |
| 517 | if ext == '.cu': |
| 518 | cuda_ext = True |
| 519 | break |
| 520 | extension = next(extension_iter, None) |
| 521 | |
| 522 | if cuda_ext and not IS_HIP_EXTENSION: |
| 523 | _check_cuda_version(compiler_name, compiler_version) |
| 524 | |
| 525 | for extension in self.extensions: |
| 526 | # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when |
| 527 | # extra_compile_args is a dict. Otherwise, default torch flags do |
| 528 | # not get passed. Necessary when only one of 'cxx' and 'nvcc' is |
| 529 | # passed to extra_compile_args in CUDAExtension, i.e. |
| 530 | # CUDAExtension(..., extra_compile_args={'cxx': [...]}) |
| 531 | # or |
| 532 | # CUDAExtension(..., extra_compile_args={'nvcc': [...]}) |
| 533 | if isinstance(extension.extra_compile_args, dict): |
| 534 | for ext in ['cxx', 'nvcc']: |
| 535 | if ext not in extension.extra_compile_args: |
| 536 | extension.extra_compile_args[ext] = [] |
| 537 | |
| 538 | self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') |
| 539 | # See note [Pybind11 ABI constants] |
| 540 | for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: |
| 541 | val = getattr(torch._C, f"_PYBIND11_{name}") |
| 542 | if val is not None and not IS_WINDOWS: |
| 543 | self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') |
| 544 | self._define_torch_extension_name(extension) |
| 545 | self._add_gnu_cpp_abi_flag(extension) |
| 546 | |
| 547 | if 'nvcc_dlink' in extension.extra_compile_args: |
| 548 | assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}." |
| 549 | |
| 550 | # Register .cu, .cuh, .hip, and .mm as valid source extensions. |
| 551 | self.compiler.src_extensions += ['.cu', '.cuh', '.hip'] |
| 552 | if torch.backends.mps.is_built(): |
| 553 | self.compiler.src_extensions += ['.mm'] |
| 554 | # Save the original _compile method for later. |
| 555 | if self.compiler.compiler_type == 'msvc': |
| 556 | self.compiler._cpp_extensions += ['.cu', '.cuh'] |
| 557 | original_compile = self.compiler.compile |
| 558 | original_spawn = self.compiler.spawn |
| 559 | else: |
| 560 | original_compile = self.compiler._compile |
| 561 | |
| 562 | def append_std17_if_no_std_present(cflags) -> None: |
| 563 | # NVCC does not allow multiple -std to be passed, so we avoid |
| 564 | # overriding the option if the user explicitly passed it. |
| 565 | cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}=' |
nothing calls this directly
no test coverage detected