(self, hidden_states, lora_layer_params=None, gegelu_limit=None)
| 151 | self.dora = None |
| 152 | |
def forward(self, hidden_states, lora_layer_params=None, gegelu_limit=None):
    """Apply ``fc``, the activation, the optional inner layernorm and ``proj``.

    Args:
        hidden_states: Input handed to ``self.fc`` (and to ``fc_gate_lora``
            on the gated-activation path).
        lora_layer_params: Optional per-layer LoRA container. When given,
            ``get_runtime_params(0, <module>)`` is queried for
            ``"mlp_gate_up"`` (must be ``None``), ``"mlp_4h_to_h"``, and,
            on the non-gated path only, ``"mlp_h_to_4h"``.
        gegelu_limit: Extra argument passed to the activation only when
            ``self.hidden_act == 'gegelu'``.

    Returns:
        Whatever ``self.proj`` returns for the activated intermediate.

    Raises:
        AssertionError: If ``lora_layer_params`` yields runtime params for
            the ``"mlp_gate_up"`` module.
    """
    has_lora = lora_layer_params is not None

    if has_lora:
        gate_up_params = lora_layer_params.get_runtime_params(
            0, "mlp_gate_up")
        assert gate_up_params is None, (
            f"LoRA module 'mlp_gate_up' is not supported in {self}")

    if not is_gated_activation(self.hidden_act):
        # Plain activation: the fc layer consumes its LoRA params directly.
        fc_lora = (lora_layer_params.get_runtime_params(0, "mlp_h_to_4h")
                   if has_lora else None)
        activations = self.fc(hidden_states, fc_lora)
    else:
        # Gated activation: fc runs without LoRA params and the LoRA
        # contribution is added on top of its output.
        activations = self.fc(hidden_states)
        lora_delta = fc_gate_lora(hidden_states, self.lora, None,
                                  lora_layer_params)
        if lora_delta is not None:
            activations = activations + lora_delta
            # NOTE(review): nesting reconstructed from a listing with
            # collapsed indentation; DoRA is applied only when a LoRA
            # result exists -- confirm against the original file.
            if self.dora is not None:
                activations = fc_gate_dora(activations, self.dora,
                                           self.fused_gate_up_dora,
                                           lora_layer_params)

    proj_lora = (lora_layer_params.get_runtime_params(0, "mlp_4h_to_h")
                 if has_lora else None)

    activation_fn = ACT2FN[self.hidden_act]
    if self.hidden_act == 'gegelu':
        activations = activation_fn(activations, gegelu_limit)
    else:
        activations = activation_fn(activations)

    if self.inner_layernorm is not None:
        activations = self.inner_layernorm(activations)

    return self.proj(activations, lora_runtime_params=proj_lora)
| 188 | |
| 189 | |
| 190 | class GatedMLP(MLP): |
no test coverage detected