(self,
hidden_states,
lora_layer_params=None,
all_reduce_params: Optional[AllReduceParams] = None)
| 229 | gather_output=False) |
| 230 | |
def forward(self,
            hidden_states,
            lora_layer_params=None,
            all_reduce_params: Optional[AllReduceParams] = None):
    """Apply the gated MLP to ``hidden_states`` and return the projection.

    The computation is ``proj(norm(act(fc(x)) * gate(x)))``, where ``act``
    is looked up in ``ACT2FN`` by ``self.hidden_act`` and ``norm`` is
    ``self.inner_layernorm`` (skipped when that attribute is ``None``).

    Args:
        hidden_states: Input passed to both ``self.fc`` and ``self.gate``.
        lora_layer_params: Optional container queried through
            ``get_runtime_params(0, <module name>)`` for the
            ``"mlp_h_to_4h"``, ``"mlp_gate"`` and ``"mlp_4h_to_h"`` LoRA
            runtime parameters. When ``None``, every sub-layer receives
            ``None`` for its LoRA parameters.
        all_reduce_params: Forwarded unchanged to ``self.proj``.

    Returns:
        Whatever ``self.proj`` returns for the gated intermediate value.

    Raises:
        AssertionError: If ``lora_layer_params`` carries runtime
            parameters for the ``"mlp_gate_up"`` module, which this
            layer does not support.
    """
    fc_lora = None
    gate_lora = None
    proj_lora = None

    if lora_layer_params is not None:
        # The fused gate/up LoRA module has no counterpart in this layer,
        # so reject it before fetching the per-sub-layer parameters.
        fused_lora = lora_layer_params.get_runtime_params(0, "mlp_gate_up")
        assert fused_lora is None, (
            f"LoRA module 'mlp_gate_up' is not supported in {self}")

        fc_lora = lora_layer_params.get_runtime_params(0, "mlp_h_to_4h")
        gate_lora = lora_layer_params.get_runtime_params(0, "mlp_gate")
        proj_lora = lora_layer_params.get_runtime_params(0, "mlp_4h_to_h")

    # Activated branch first, then the gate branch, then their product.
    fc_out = self.fc(hidden_states, fc_lora)
    activated = ACT2FN[self.hidden_act](fc_out)
    gate_out = self.gate(hidden_states, gate_lora)
    gated = activated * gate_out

    if self.inner_layernorm is not None:
        gated = self.inner_layernorm(gated)

    return self.proj(gated,
                     lora_runtime_params=proj_lora,
                     all_reduce_params=all_reduce_params)
| 265 | |
| 266 | |
| 267 | class FusedGatedMLP(Module): |
nothing calls this directly
no test coverage detected