(self, in_features: int, out_features: int,
experts_per_node: int, quant_mode: QuantMode,
groupwise_quant_algo: int, group_size: int,
dtype: Union[str,
trt.DataType], weight_dtype: Union[str,
trt.DataType],
has_bias: bool, wrapper_tllm_to_externel_key_dict: dict,
tp_size: int, tp_dim: int)
| 423 | class MOEWeightWrapper(Module): |
| 424 | |
| 425 | def __init__(self, in_features: int, out_features: int, |
| 426 | experts_per_node: int, quant_mode: QuantMode, |
| 427 | groupwise_quant_algo: int, group_size: int, |
| 428 | dtype: Union[str, |
| 429 | trt.DataType], weight_dtype: Union[str, |
| 430 | trt.DataType], |
| 431 | has_bias: bool, wrapper_tllm_to_externel_key_dict: dict, |
| 432 | tp_size: int, tp_dim: int): |
| 433 | super().__init__() |
| 434 | self.quant_mode = quant_mode |
| 435 | self.groupwise_quant_algo = groupwise_quant_algo |
| 436 | self.group_size = group_size |
| 437 | self.expert_shape = (experts_per_node, out_features, in_features) |
| 438 | self.dtype = dtype |
| 439 | self.weight_dtype = weight_dtype |
| 440 | self.has_bias = has_bias |
| 441 | self.tllm_to_externel_key_dict = wrapper_tllm_to_externel_key_dict |
| 442 | self.tp_size = tp_size |
| 443 | self.tp_dim = 1 - tp_dim if quant_mode.has_per_group_scaling( |
| 444 | ) else tp_dim |
| 445 | self.is_padded = False |
| 446 | |
| 447 | if quant_mode.is_weight_only( |
| 448 | ) and not quant_mode.has_per_group_scaling(): |
| 449 | bytes_per_col_scale = 2 if quant_mode.is_int4_weight_only() else 1 |
| 450 | # We use a different shape here because the quantized weights have their own layout |
| 451 | self.expert_shape = (experts_per_node, in_features, |
| 452 | out_features // bytes_per_col_scale) |
| 453 | self.per_channel_scale = Parameter(shape=(experts_per_node, |
| 454 | out_features), |
| 455 | dtype=dtype) |
| 456 | else: |
| 457 | self.register_parameter('per_channel_scale', None) |
| 458 | |
| 459 | if quant_mode.has_nvfp4(): |
| 460 | self.expert_shape = (experts_per_node, out_features, in_features) |
| 461 | weight_dtype = trt.fp4 |
| 462 | |
| 463 | if not quant_mode.has_per_group_scaling(): |
| 464 | self.weight = Parameter(shape=self.expert_shape, |
| 465 | dtype=weight_dtype, |
| 466 | prefer_managed=True) |
| 467 | |
| 468 | if has_bias: |
| 469 | self.bias = Parameter(shape=(experts_per_node, out_features), |
| 470 | dtype=dtype) |
| 471 | else: |
| 472 | self.register_parameter('bias', None) |
| 473 | |
| 474 | self.scaling_vector_size = 16 |
| 475 | if quant_mode.has_fp8_qdq(): |
| 476 | self.activation_scaling_factor = Parameter(shape=(1, ), |
| 477 | dtype=trt.float32) |
| 478 | self.weights_scaling_factor = Parameter(shape=(experts_per_node, 1), |
| 479 | dtype=trt.float32) |
| 480 | elif quant_mode.has_nvfp4(): |
| 481 | self.weights_block_scaling_factor_interleaved = Parameter( |
| 482 | shape=(experts_per_node, out_features, |
nothing calls this directly
no test coverage detected