MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / __init__

Method __init__

tensorrt_llm/_torch/modules/linear.py:2071–2190  ·  view source on GitHub ↗

Args: nvfp4_allowed_backends: List of backends to consider for NVFP4 GEMM auto-selection. Default (via config): ['cutlass', 'cublaslt', 'cuda_core'] - excludes cutedsl for faster build. Add 'cutedsl' for extreme performance at the cost of longer build time.

(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        mapping: Optional[Mapping] = None,
        tensor_parallel_mode: Optional[TensorParallelMode] = None,
        gather_output: bool = False,  # COLUMN parallel only
        quant_config: Optional[QuantConfig] = None,
        weights_loading_config: Optional[WeightsLoadingConfig] = None,
        reduce_output: bool = True,  # ROW parallel only
        skip_create_weights_in_init: bool = False,
        use_custom_cublas_mm: bool = False,
        lora: Optional[LoraLayer] = None,
        allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
        force_dynamic_quantization: bool = False,
        use_cute_dsl_blockscaling_mm: bool = False,
        disable_deep_gemm: bool = False,
        fused_weight_shard_indices_mapping: Optional[dict] = None,
        nvfp4_allowed_backends: Optional[List[str]] = None,
        enable_gemm_allreduce_fusion: bool = True,
    )

Source from the content-addressed store, hash-verified

2069class Linear(nn.Module):
2070
2071 def __init__(
2072 self,
2073 in_features: int,
2074 out_features: int,
2075 bias: bool = True,
2076 dtype: torch.dtype = None,
2077 mapping: Optional[Mapping] = None,
2078 tensor_parallel_mode: Optional[TensorParallelMode] = None,
2079 gather_output: bool = False, # COLUMN parallel only
2080 quant_config: Optional[QuantConfig] = None,
2081 weights_loading_config: Optional[WeightsLoadingConfig] = None,
2082 reduce_output: bool = True, # ROW parallel only
2083 skip_create_weights_in_init: bool = False,
2084 use_custom_cublas_mm: bool = False,
2085 lora: Optional[LoraLayer] = None,
2086 allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
2087 force_dynamic_quantization: bool = False,
2088 use_cute_dsl_blockscaling_mm: bool = False,
2089 disable_deep_gemm: bool = False,
2090 fused_weight_shard_indices_mapping: Optional[dict] = None,
2091 nvfp4_allowed_backends: Optional[List[str]] = None,
2092 enable_gemm_allreduce_fusion: bool = True,
2093 ):
2094 """
2095 Args:
2096 nvfp4_allowed_backends: List of backends to consider for NVFP4 GEMM auto-selection.
2097 Default (via config): ['cutlass', 'cublaslt', 'cuda_core'] - excludes cutedsl for faster build.
2098 Add 'cutedsl' for extreme performance at the cost of longer build time.
2099 Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'.
2100 Configure via nvfp4_gemm_config.allowed_backends in extra_llm_api_options.yaml.
2101 """
2102 from ..distributed import AllReduce
2103
2104 super().__init__()
2105 self.has_bias = bias
2106 self.dtype = dtype
2107 self.mapping = mapping or Mapping()
2108 # could be modified later
2109 self.quant_config = quant_config
2110 self.weights_loading_config = weights_loading_config or WeightsLoadingConfig(
2111 )
2112 self.tp_size = self.mapping.tp_size
2113 self.tp_rank = self.mapping.tp_rank
2114 self.tp_mode = tensor_parallel_mode
2115 self.gather_output = gather_output
2116 self.force_dynamic_quantization = force_dynamic_quantization
2117 self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
2118 self.disable_deep_gemm = disable_deep_gemm
2119 self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping
2120
2121 # Store NVFP4 GEMM allowed backends configuration
2122 # Read from model_extra_attrs if not explicitly provided (allows config via llm_api_options)
2123 if nvfp4_allowed_backends is None:
2124 model_attrs = get_model_extra_attrs()
2125 if model_attrs:
2126 nvfp4_allowed_backends = model_attrs.get(
2127 'nvfp4_gemm_allowed_backends')
2128 # Default: exclude cutedsl for faster build time

Callers

nothing calls this directly

Calls 10

create_weights — Method · 0.95
Mapping — Class · 0.90
mpi_disabled — Function · 0.90
get_model_extra_attrs — Function · 0.85
AllReduce — Class · 0.85
get_sm_version — Function · 0.50
get — Method · 0.45
has_nvfp4 — Method · 0.45
device — Method · 0.45

Tested by

no test coverage detected