Args: nvfp4_allowed_backends: List of backends to consider for NVFP4 GEMM auto-selection. Default (via config): ['cutlass', 'cublaslt', 'cuda_core'] - excludes cutedsl for faster build. Add 'cutedsl' for extreme performance at the cost of longer build time.
(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False, # COLUMN parallel only
quant_config: Optional[QuantConfig] = None,
weights_loading_config: Optional[WeightsLoadingConfig] = None,
reduce_output: bool = True, # ROW parallel only
skip_create_weights_in_init: bool = False,
use_custom_cublas_mm: bool = False,
lora: Optional[LoraLayer] = None,
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
force_dynamic_quantization: bool = False,
use_cute_dsl_blockscaling_mm: bool = False,
disable_deep_gemm: bool = False,
fused_weight_shard_indices_mapping: Optional[dict] = None,
nvfp4_allowed_backends: Optional[List[str]] = None,
enable_gemm_allreduce_fusion: bool = True,
)
| 2069 | class Linear(nn.Module): |
| 2070 | |
| 2071 | def __init__( |
| 2072 | self, |
| 2073 | in_features: int, |
| 2074 | out_features: int, |
| 2075 | bias: bool = True, |
| 2076 | dtype: torch.dtype = None, |
| 2077 | mapping: Optional[Mapping] = None, |
| 2078 | tensor_parallel_mode: Optional[TensorParallelMode] = None, |
| 2079 | gather_output: bool = False, # COLUMN parallel only |
| 2080 | quant_config: Optional[QuantConfig] = None, |
| 2081 | weights_loading_config: Optional[WeightsLoadingConfig] = None, |
| 2082 | reduce_output: bool = True, # ROW parallel only |
| 2083 | skip_create_weights_in_init: bool = False, |
| 2084 | use_custom_cublas_mm: bool = False, |
| 2085 | lora: Optional[LoraLayer] = None, |
| 2086 | allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, |
| 2087 | force_dynamic_quantization: bool = False, |
| 2088 | use_cute_dsl_blockscaling_mm: bool = False, |
| 2089 | disable_deep_gemm: bool = False, |
| 2090 | fused_weight_shard_indices_mapping: Optional[dict] = None, |
| 2091 | nvfp4_allowed_backends: Optional[List[str]] = None, |
| 2092 | enable_gemm_allreduce_fusion: bool = True, |
| 2093 | ): |
| 2094 | """ |
| 2095 | Args: |
| 2096 | nvfp4_allowed_backends: List of backends to consider for NVFP4 GEMM auto-selection. |
| 2097 | Default (via config): ['cutlass', 'cublaslt', 'cuda_core'] - excludes cutedsl for faster build. |
| 2098 | Add 'cutedsl' for extreme performance at the cost of longer build time. |
| 2099 | Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'. |
| 2100 | Configure via nvfp4_gemm_config.allowed_backends in extra_llm_api_options.yaml. |
| 2101 | """ |
| 2102 | from ..distributed import AllReduce |
| 2103 | |
| 2104 | super().__init__() |
| 2105 | self.has_bias = bias |
| 2106 | self.dtype = dtype |
| 2107 | self.mapping = mapping or Mapping() |
| 2108 | # could be modified later |
| 2109 | self.quant_config = quant_config |
| 2110 | self.weights_loading_config = weights_loading_config or WeightsLoadingConfig( |
| 2111 | ) |
| 2112 | self.tp_size = self.mapping.tp_size |
| 2113 | self.tp_rank = self.mapping.tp_rank |
| 2114 | self.tp_mode = tensor_parallel_mode |
| 2115 | self.gather_output = gather_output |
| 2116 | self.force_dynamic_quantization = force_dynamic_quantization |
| 2117 | self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm |
| 2118 | self.disable_deep_gemm = disable_deep_gemm |
| 2119 | self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping |
| 2120 | |
| 2121 | # Store NVFP4 GEMM allowed backends configuration |
| 2122 | # Read from model_extra_attrs if not explicitly provided (allows config via llm_api_options) |
| 2123 | if nvfp4_allowed_backends is None: |
| 2124 | model_attrs = get_model_extra_attrs() |
| 2125 | if model_attrs: |
| 2126 | nvfp4_allowed_backends = model_attrs.get( |
| 2127 | 'nvfp4_gemm_allowed_backends') |
| 2128 | # Default: exclude cutedsl for faster build time |
nothing calls this directly
no test coverage detected