(hidden_states, lora, fused_gate_up_lora, lora_layer_params)
| 31 | |
| 32 | |
def fc_gate_lora(hidden_states, lora, fused_gate_up_lora, lora_layer_params):
    """Return the LoRA output for the MLP fc/gate input projections, if any.

    Runtime params are looked up on ``lora_layer_params`` via
    ``get_runtime_params(0, <module>)`` for the modules ``"mlp_h_to_4h"``,
    ``"mlp_gate"`` and ``"mlp_gate_up"`` (in that order).

    Resolution order:

    1. If ``"mlp_gate_up"`` params exist, the result of
       ``fused_gate_up_lora(hidden_states, params)`` is returned as-is.
    2. Otherwise, if both ``"mlp_h_to_4h"`` and ``"mlp_gate"`` params exist,
       they are packed into a single ``LoraRuntimeParams`` and passed to
       ``lora``; its two outputs are concatenated (gate first, then fc)
       along the last dimension of the fc output.
    3. Otherwise ``None`` is returned.

    Args:
        hidden_states: Input forwarded unchanged to the LoRA callables.
        lora: Callable used for the separate fc + gate path; must return
            two values (fc output, gate output).
        fused_gate_up_lora: Callable used for the fused gate-up path; must
            not be ``None`` when ``"mlp_gate_up"`` params are present.
        lora_layer_params: Object exposing ``get_runtime_params``, or
            ``None`` when no LoRA is configured.

    Returns:
        The LoRA output for the selected path, or ``None`` when no
        applicable LoRA params are available.

    Raises:
        AssertionError: If ``"mlp_gate_up"`` params are present but
            ``fused_gate_up_lora`` is ``None``.
    """
    if lora_layer_params is None:
        return None

    fc_params = lora_layer_params.get_runtime_params(0, "mlp_h_to_4h")
    gate_params = lora_layer_params.get_runtime_params(0, "mlp_gate")
    gate_up_params = lora_layer_params.get_runtime_params(0, "mlp_gate_up")

    # The fused gate-up adapter takes precedence over separate fc/gate ones.
    if gate_up_params is not None:
        assert fused_gate_up_lora is not None
        return fused_gate_up_lora(hidden_states, gate_up_params)

    # The separate path needs both adapters; with only one, nothing is applied.
    if fc_params is None or gate_params is None:
        return None

    # Only the first rank / weights-pointer entry of each adapter is used;
    # the host-side metadata is taken from the fc adapter's params.
    combined_params = LoraRuntimeParams(
        lora_ranks=[fc_params.lora_ranks[0], gate_params.lora_ranks[0]],
        lora_weights_pointers=[
            fc_params.lora_weights_pointers[0],
            gate_params.lora_weights_pointers[0],
        ],
        host_request_types=fc_params.host_request_types,
        host_context_lengths=fc_params.host_context_lengths,
    )

    fc_out, gate_out = lora(hidden_states, combined_params)
    # Gate output goes first in the concatenation, fc output second.
    return concat([gate_out, fc_out], dim=fc_out.rank() - 1)
| 66 | |
| 67 | |
| 68 | def fc_gate_dora(hidden_states, dora, fused_gate_up_dora, lora_layer_params): |
no test coverage detected