MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / fc_gate_dora

Function fc_gate_dora

tensorrt_llm/layers/mlp.py:68–95  ·  view source on GitHub ↗
(hidden_states, dora, fused_gate_up_dora, lora_layer_params)

Source from the content-addressed store, hash-verified

66
67
def fc_gate_dora(hidden_states, dora, fused_gate_up_dora, lora_layer_params):
    """Apply DoRA to the MLP's fc / gate input projections, if configured.

    Runtime params are looked up for three module names, in this order:
    ``"mlp_h_to_4h"`` (fc), ``"mlp_gate"`` and ``"mlp_gate_up"``.

    A fused ``"mlp_gate_up"`` entry takes precedence and is passed to
    ``fused_gate_up_dora``. Otherwise, when both ``"mlp_h_to_4h"`` and
    ``"mlp_gate"`` entries exist, they are merged into a single
    ``LoraRuntimeParams`` (fc first, gate second) and passed to ``dora``.

    Args:
        hidden_states: Input forwarded unchanged to whichever DoRA callable
            runs.
        dora: Callable invoked as ``dora(hidden_states, params)`` for the
            separate fc + gate case.
        fused_gate_up_dora: Callable invoked as
            ``fused_gate_up_dora(hidden_states, params)`` for the fused case.
        lora_layer_params: Object exposing ``get_runtime_params(0, name)``,
            or ``None`` to skip DoRA entirely.

    Returns:
        The result of the selected DoRA callable, or ``None`` when
        ``lora_layer_params`` is ``None`` or no matching runtime params are
        found (including when only one of fc / gate has params).

    Raises:
        ValueError: If the runtime params select a DoRA callable that was
            passed as ``None``.
    """
    if lora_layer_params is None:
        return None

    # NOTE(review): the leading ``0`` is presumably a layer index local to
    # this layer's params; confirm against get_runtime_params.
    # All three lookups run before any branching (same order as before).
    mlp_fc_lora_params = lora_layer_params.get_runtime_params(
        0, "mlp_h_to_4h")
    mlp_gate_lora_params = lora_layer_params.get_runtime_params(
        0, "mlp_gate")
    mlp_gate_up_lora_params = lora_layer_params.get_runtime_params(
        0, "mlp_gate_up")

    if mlp_gate_up_lora_params is not None:
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque TypeError.
        if fused_gate_up_dora is None:
            raise ValueError(
                "fc_gate_dora: 'mlp_gate_up' runtime params are present but "
                "fused_gate_up_dora is None")
        return fused_gate_up_dora(hidden_states, mlp_gate_up_lora_params)

    if mlp_fc_lora_params is None or mlp_gate_lora_params is None:
        # NOTE(review): a lone fc or gate entry is silently ignored here.
        return None

    if dora is None:
        raise ValueError(
            "fc_gate_dora: 'mlp_h_to_4h' and 'mlp_gate' runtime params are "
            "present but dora is None")

    # Merge fc and gate into one request, ordered [fc, gate]. This order
    # presumably has to match how ``dora`` was built; TODO confirm.
    # Only entry [0] of each ranks / pointers list is used.
    # host_request_types and host_context_lengths come from the fc params only.
    mlp_in_lora_params = LoraRuntimeParams(
        lora_ranks=[
            mlp_fc_lora_params.lora_ranks[0],
            mlp_gate_lora_params.lora_ranks[0],
        ],
        lora_weights_pointers=[
            mlp_fc_lora_params.lora_weights_pointers[0],
            mlp_gate_lora_params.lora_weights_pointers[0],
        ],
        host_request_types=mlp_fc_lora_params.host_request_types,
        host_context_lengths=mlp_fc_lora_params.host_context_lengths)
    return dora(hidden_states, mlp_in_lora_params)
96
97
98class MLP(Module):

Callers 2

forward (method) · 0.85
fc_gate (method) · 0.85

Calls 2

LoraRuntimeParams (class) · 0.85
get_runtime_params (method) · 0.80

Tested by

no test coverage detected