(self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
layer_idx: Optional[int] = None,
reduce_output: bool = True,
overridden_tp_size: Optional[int] = None)
| 14 | class MLP(nn.Module): |
| 15 | |
| 16 | def __init__(self, |
| 17 | *, |
| 18 | hidden_size: int, |
| 19 | intermediate_size: int, |
| 20 | bias: bool, |
| 21 | activation: Callable[[torch.Tensor], torch.Tensor] = None, |
| 22 | dtype: Optional[torch.dtype] = None, |
| 23 | config: Optional[ModelConfig] = None, |
| 24 | layer_idx: Optional[int] = None, |
| 25 | reduce_output: bool = True, |
| 26 | overridden_tp_size: Optional[int] = None): |
| 27 | |
| 28 | super().__init__() |
| 29 | self.layer_idx = layer_idx |
| 30 | self.hidden_size = hidden_size |
| 31 | self.intermediate_size = intermediate_size |
| 32 | self.activation = activation |
| 33 | |
| 34 | config = config or ModelConfig() |
| 35 | self.mapping = config.mapping |
| 36 | if overridden_tp_size is not None: |
| 37 | assert config.mapping.tp_size % overridden_tp_size == 0 |
| 38 | tp_size = overridden_tp_size |
| 39 | # "Misuse" pp_size here to perform all-reduce within smaller groups |
| 40 | pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size |
| 41 | mapping = Mapping( |
| 42 | world_size=tp_size * pp_size, |
| 43 | rank=self.mapping.rank, |
| 44 | gpus_per_node=self.mapping.gpus_per_node, |
| 45 | tp_size=tp_size, |
| 46 | pp_size=pp_size, |
| 47 | ) |
| 48 | else: |
| 49 | mapping = config.mapping |
| 50 | |
| 51 | self.up_lora = LoraLayer( |
| 52 | [LoraModuleType.MLP_H_TO_4H], |
| 53 | [self.intermediate_size // config.mapping.tp_size]) |
| 54 | |
| 55 | self.up_proj = Linear( |
| 56 | self.hidden_size, |
| 57 | self.intermediate_size, |
| 58 | bias=bias, |
| 59 | dtype=dtype, |
| 60 | mapping=mapping, |
| 61 | tensor_parallel_mode=TensorParallelMode.COLUMN, |
| 62 | weights_loading_config=WeightsLoadingConfig( |
| 63 | weight_mode=WeightMode.VANILLA), |
| 64 | quant_config=config.get_quant_config(), |
| 65 | skip_create_weights_in_init=config.skip_create_weights_in_init, |
| 66 | lora=self.up_lora, |
| 67 | allreduce_strategy=config.allreduce_strategy, |
| 68 | force_dynamic_quantization=config.force_dynamic_quantization) |
| 69 | |
| 70 | self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], |
| 71 | [self.hidden_size]) |
| 72 | self.down_proj = Linear( |
| 73 | self.intermediate_size, |
nothing calls this directly
no test coverage detected