MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / fc_gate_lora

Function fc_gate_lora

tensorrt_llm/layers/mlp.py:33–65  ·  view source on GitHub ↗
(hidden_states, lora, fused_gate_up_lora, lora_layer_params)

Source from the content-addressed store, hash-verified

31
32
33def fc_gate_lora(hidden_states, lora, fused_gate_up_lora, lora_layer_params):
34 if lora_layer_params is not None:
35 mlp_fc_lora_params = lora_layer_params.get_runtime_params(
36 0, "mlp_h_to_4h")
37 mlp_gate_lora_params = lora_layer_params.get_runtime_params(
38 0, "mlp_gate")
39 mlp_gate_up_lora_params = lora_layer_params.get_runtime_params(
40 0, "mlp_gate_up")
41
42 if mlp_gate_up_lora_params is not None:
43 assert fused_gate_up_lora is not None
44 mlp_gate_up_lora = fused_gate_up_lora(hidden_states,
45 mlp_gate_up_lora_params)
46 return mlp_gate_up_lora
47
48 elif mlp_fc_lora_params is not None and mlp_gate_lora_params is not None:
49 mlp_in_lora_params = LoraRuntimeParams(
50 lora_ranks=[
51 mlp_fc_lora_params.lora_ranks[0],
52 mlp_gate_lora_params.lora_ranks[0]
53 ],
54 lora_weights_pointers=[
55 mlp_fc_lora_params.lora_weights_pointers[0],
56 mlp_gate_lora_params.lora_weights_pointers[0]
57 ],
58 host_request_types=mlp_fc_lora_params.host_request_types,
59 host_context_lengths=mlp_fc_lora_params.host_context_lengths)
60
61 mlp_fc_lora, mlp_gate_lora = lora(hidden_states, mlp_in_lora_params)
62 mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora],
63 dim=mlp_fc_lora.rank() - 1)
64 return mlp_in_result
65 return None
66
67
68def fc_gate_dora(hidden_states, dora, fused_gate_up_dora, lora_layer_params):

Callers 3

forwardMethod · 0.85
fc_gate_pluginMethod · 0.85
fc_gateMethod · 0.85

Calls 4

LoraRuntimeParamsClass · 0.85
concatFunction · 0.85
get_runtime_paramsMethod · 0.80
rankMethod · 0.45

Tested by

no test coverage detected