(hidden_states, dora, fused_gate_up_dora, lora_layer_params)
| 66 | |
| 67 | |
def fc_gate_dora(hidden_states, dora, fused_gate_up_dora, lora_layer_params):
    """Apply DoRA to the MLP fc/gate input projection when LoRA params exist.

    Looks up the layer-0 runtime params for "mlp_h_to_4h", "mlp_gate" and
    "mlp_gate_up" on ``lora_layer_params``. A fused "mlp_gate_up" entry takes
    precedence and is handed to ``fused_gate_up_dora``. Otherwise, when both
    the separate fc and gate entries exist, they are merged into a single
    ``LoraRuntimeParams`` and passed to ``dora``.

    Args:
        hidden_states: Input forwarded unchanged to the selected DoRA callable.
        dora: Callable used for the separate fc + gate path.
        fused_gate_up_dora: Callable used for the fused gate_up path; must not
            be None whenever an "mlp_gate_up" entry is present.
        lora_layer_params: Object exposing ``get_runtime_params(index, name)``,
            or None.

    Returns:
        The result of whichever DoRA callable is invoked, or None when no
        applicable params are found (including when ``lora_layer_params`` is
        None).

    Raises:
        AssertionError: If an "mlp_gate_up" entry exists but
            ``fused_gate_up_dora`` is None (only while assertions are enabled).
    """
    if lora_layer_params is None:
        return None

    # All three lookups run up front, in this order, before any branching.
    fc_params, gate_params, gate_up_params = [
        lora_layer_params.get_runtime_params(0, module_name)
        for module_name in ("mlp_h_to_4h", "mlp_gate", "mlp_gate_up")
    ]

    if gate_up_params is not None:
        # Fused gate_up weights win over the separate fc/gate weights.
        assert fused_gate_up_dora is not None
        return fused_gate_up_dora(hidden_states, gate_up_params)

    if fc_params is None or gate_params is None:
        return None

    # Merge the two entries: rank and pointer lists are ordered [fc, gate];
    # request types and context lengths are taken from the fc entry.
    merged_params = LoraRuntimeParams(
        lora_ranks=[fc_params.lora_ranks[0], gate_params.lora_ranks[0]],
        lora_weights_pointers=[
            fc_params.lora_weights_pointers[0],
            gate_params.lora_weights_pointers[0],
        ],
        host_request_types=fc_params.host_request_types,
        host_context_lengths=fc_params.host_context_lengths,
    )
    return dora(hidden_states, merged_params)
| 96 | |
| 97 | |
| 98 | class MLP(Module): |
no test coverage detected