hub / github.com/NVIDIA/TensorRT-LLM / _moe_plugin

Function _moe_plugin

tensorrt_llm/layers/moe.py:145–405
(moe_config,
                hidden_states,
                hidden_states_raw,
                token_selected_experts,
                token_final_scales,
                expert_weights_1,
                expert_weights_2,
                expert_bias_1,
                expert_bias_2,
                expert_scale_1,
                expert_scale_2,
                expert_scale_3,
                expert_scale_4,
                expert_scale_5,
                expert_scale_6,
                groupwise_quant_params,
                hidden_size,
                ffn_hidden_size,
                act_fn,
                dtype,
                weight_dtype,
                output_dtype,
                lora_params: LoraParams,
                lora_max_low_rank,
                quant_mode=QuantMode(0),
                tp_size=1,
                ep_size=1,
                tp_rank=0,
                ep_rank=0,
                side_stream_id=SideStreamIDType.disable)

Source from the content-addressed store, hash-verified

def _moe_plugin(moe_config,
                hidden_states,
                hidden_states_raw,
                token_selected_experts,
                token_final_scales,
                expert_weights_1,
                expert_weights_2,
                expert_bias_1,
                expert_bias_2,
                expert_scale_1,
                expert_scale_2,
                expert_scale_3,
                expert_scale_4,
                expert_scale_5,
                expert_scale_6,
                groupwise_quant_params,
                hidden_size,
                ffn_hidden_size,
                act_fn,
                dtype,
                weight_dtype,
                output_dtype,
                lora_params: LoraParams,
                lora_max_low_rank,
                quant_mode=QuantMode(0),
                tp_size=1,
                ep_size=1,
                tp_rank=0,
                ep_rank=0,
                side_stream_id=SideStreamIDType.disable):
    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)

    if isinstance(weight_dtype, str):
        weight_dtype = str_dtype_to_trt(weight_dtype)

    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    def from_parameter(x):
        if isinstance(x, Parameter):
            return x.value
        return x

    expert_weights_1 = from_parameter(expert_weights_1)
    expert_weights_2 = from_parameter(expert_weights_2)
    expert_bias_1 = from_parameter(expert_bias_1)
    expert_bias_2 = from_parameter(expert_bias_2)
    expert_scale_1 = from_parameter(expert_scale_1)
    expert_scale_2 = from_parameter(expert_scale_2)
    expert_scale_3 = from_parameter(expert_scale_3)
    expert_scale_4 = from_parameter(expert_scale_4)
    expert_scale_5 = from_parameter(expert_scale_5)
    expert_scale_6 = from_parameter(expert_scale_6)

    # Create the plugin with our required state
    num_experts = moe_config.num_experts
    p_remove_input_padding = trt.PluginField(
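
The excerpt above begins by normalizing its inputs: dtype arguments given as strings are converted to enum values, and arguments that may be wrapped in a `Parameter` are unwrapped to the underlying value. The sketch below illustrates that pattern in isolation. It does not use TensorRT or TensorRT-LLM; `DType`, the `Parameter` dataclass, and `to_dtype` are hypothetical stand-ins introduced only for this example, and only `from_parameter` mirrors the helper shown in the source.

```python
from dataclasses import dataclass
from enum import Enum


class DType(Enum):
    """Hypothetical stand-in for a backend data type enum (not trt.DataType)."""
    FLOAT16 = "float16"
    FLOAT32 = "float32"


@dataclass
class Parameter:
    """Hypothetical stand-in for a layer parameter that wraps a value."""
    value: object


def to_dtype(dtype):
    """Accept a dtype name or an enum member and return the enum member."""
    if isinstance(dtype, str):
        return DType(dtype)  # look the member up by its string value
    return dtype


def from_parameter(x):
    """Unwrap a Parameter to its value; pass anything else through unchanged."""
    if isinstance(x, Parameter):
        return x.value
    return x


# Wrapped and unwrapped inputs are both accepted, as are optional (None) ones.
weights = from_parameter(Parameter(value=[1.0, 2.0]))
bias = from_parameter(None)
dtype = to_dtype("float16")
```

Normalizing at the top of the function means the rest of the body can assume one representation for each argument, regardless of how the caller supplied it.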

Callers 1

forward_experts · Method · 0.85

Calls 15

str_dtype_to_trt · Function · 0.90
QuantMode · Class · 0.85
from_parameter · Function · 0.85
default_net · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
int32 · Method · 0.80
create_plugin · Method · 0.80
is_weight_only · Method · 0.80
has_per_group_scaling · Method · 0.80
get_runtime_params · Method · 0.80

Tested by

no test coverage detected