(moe_config,
hidden_states,
hidden_states_raw,
token_selected_experts,
token_final_scales,
expert_weights_1,
expert_weights_2,
expert_bias_1,
expert_bias_2,
expert_scale_1,
expert_scale_2,
expert_scale_3,
expert_scale_4,
expert_scale_5,
expert_scale_6,
groupwise_quant_params,
hidden_size,
ffn_hidden_size,
act_fn,
dtype,
weight_dtype,
output_dtype,
lora_params: LoraParams,
lora_max_low_rank,
quant_mode=QuantMode(0),
tp_size=1,
ep_size=1,
tp_rank=0,
ep_rank=0,
side_stream_id=SideStreamIDType.disable)
| 143 | |
| 144 | |
| 145 | def _moe_plugin(moe_config, |
| 146 | hidden_states, |
| 147 | hidden_states_raw, |
| 148 | token_selected_experts, |
| 149 | token_final_scales, |
| 150 | expert_weights_1, |
| 151 | expert_weights_2, |
| 152 | expert_bias_1, |
| 153 | expert_bias_2, |
| 154 | expert_scale_1, |
| 155 | expert_scale_2, |
| 156 | expert_scale_3, |
| 157 | expert_scale_4, |
| 158 | expert_scale_5, |
| 159 | expert_scale_6, |
| 160 | groupwise_quant_params, |
| 161 | hidden_size, |
| 162 | ffn_hidden_size, |
| 163 | act_fn, |
| 164 | dtype, |
| 165 | weight_dtype, |
| 166 | output_dtype, |
| 167 | lora_params: LoraParams, |
| 168 | lora_max_low_rank, |
| 169 | quant_mode=QuantMode(0), |
| 170 | tp_size=1, |
| 171 | ep_size=1, |
| 172 | tp_rank=0, |
| 173 | ep_rank=0, |
| 174 | side_stream_id=SideStreamIDType.disable): |
| 175 | if isinstance(dtype, str): |
| 176 | dtype = str_dtype_to_trt(dtype) |
| 177 | |
| 178 | if isinstance(weight_dtype, str): |
| 179 | weight_dtype = str_dtype_to_trt(weight_dtype) |
| 180 | |
| 181 | if isinstance(output_dtype, str): |
| 182 | output_dtype = str_dtype_to_trt(output_dtype) |
| 183 | |
| 184 | def from_parameter(x): |
| 185 | if isinstance(x, Parameter): |
| 186 | return x.value |
| 187 | return x |
| 188 | |
| 189 | expert_weights_1 = from_parameter(expert_weights_1) |
| 190 | expert_weights_2 = from_parameter(expert_weights_2) |
| 191 | expert_bias_1 = from_parameter(expert_bias_1) |
| 192 | expert_bias_2 = from_parameter(expert_bias_2) |
| 193 | expert_scale_1 = from_parameter(expert_scale_1) |
| 194 | expert_scale_2 = from_parameter(expert_scale_2) |
| 195 | expert_scale_3 = from_parameter(expert_scale_3) |
| 196 | expert_scale_4 = from_parameter(expert_scale_4) |
| 197 | expert_scale_5 = from_parameter(expert_scale_5) |
| 198 | expert_scale_6 = from_parameter(expert_scale_6) |
| 199 | |
| 200 | # Create the plugin with our required state |
| 201 | num_experts = moe_config.num_experts |
| 202 | p_remove_input_padding = trt.PluginField( |
no test coverage detected