hub / github.com/NVIDIA/TensorRT-LLM / forward_experts

Method forward_experts

tensorrt_llm/layers/moe.py:1290–1384
(self, hidden_states, token_selected_experts, token_final_scales, lora_layer_params, side_stream_id)

Source from the content-addressed store, hash-verified

    def forward_experts(self, hidden_states, token_selected_experts,
                        token_final_scales, lora_layer_params, side_stream_id):
        assert side_stream_id == SideStreamIDType.disable, "MoeOOTB does not support using side stream"
        # TODO: https://nvbugspro.nvidia.com/bug/4781396 after this nvbug is fixed, we will remove this check.
        if lora_layer_params is not None:
            for module in ["mlp_h_to_4h", "mlp_4h_to_h", "mlp_gate"]:
                if lora_layer_params.get_runtime_params(0, module) is not None:
                    raise RuntimeError(
                        f"MoE OOTB does not support {module} LoRA module, please enable MoE plugin"
                    )

        topk_indices = token_selected_experts
        topk_values = token_final_scales

        hidden_size = shape(hidden_states, -1)
        # [B*sq, hidden]
        inputs_merged = hidden_states.view(concat([-1, hidden_size]))
        flat_topk_indices = topk_indices.view(
            concat([-1, shape(topk_indices, -1)]))
        flat_topk_values = topk_values.view(concat([-1,
                                                    shape(topk_values, -1)]))

        # Create output space
        zero_buffer = inputs_merged * 0.0
        output = zero_buffer

        expert_indices_stack = []
        indices_stack = []
        # When topk indices are equal to expert index, the expert will inference the tokens.
        # Bundle all indices and experts index, then do mask once.
        for i, expert in enumerate(self.experts):
            if self.mapping.has_moe_ep():
                index = i + self.experts_per_node * self.mapping.moe_ep_rank
            else:
                index = i
            expert_indices_stack.append(
                flat_topk_indices.view(concat([1, shape(flat_topk_indices)])))

            indices_stack.append(constant(int32_array(index)))

        all_expert_indices = concat(expert_indices_stack, dim=0)
        indices = expand(
            concat(indices_stack).view(concat([len(self.experts), 1, 1])),
            shape(all_expert_indices))

        # Create all experts mask
        all_expert_mask = all_expert_indices == indices

        experts_weights = cast(
            sum(flat_topk_values *
                cast(all_expert_mask, flat_topk_values.dtype),
                dim=-1,
                keepdim=True), self.dtype)

        all_expert_mask = cast(
            sum(cast(all_expert_mask, flat_topk_values.dtype),
                dim=-1,
                keepdim=True), 'bool')
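
The mask and weight step in the listing above can be hard to follow in graph-building form, so here is a minimal NumPy sketch of the same idea. It is an eager analogue, not TensorRT-LLM code: the function name `expert_mask_and_weights`, its parameters, and the sample routing data are all hypothetical, and NumPy broadcasting stands in for the explicit `concat`/`expand` calls. It assumes, as the listing suggests, that each rank owns a contiguous block of `experts_per_node` global expert ids.

```python
import numpy as np


def expert_mask_and_weights(topk_indices, topk_values, num_local_experts,
                            experts_per_node=None, ep_rank=0):
    """Illustrative NumPy analogue of the mask/weight step (hypothetical helper)."""
    # Global ids of the experts held by this rank. Without expert parallelism
    # the offset is zero, matching the `index = i` branch in the listing.
    offset = 0 if experts_per_node is None else experts_per_node * ep_rank
    expert_ids = (np.arange(num_local_experts) + offset).reshape(-1, 1, 1)

    # Shape [E, T, K]: True where a token's k-th routed choice is this expert.
    mask = topk_indices[None, :, :] == expert_ids

    # Shape [E, T, 1]: the routing scale each expert applies to each token
    # (zero when the token was not routed to that expert).
    weights = (topk_values[None, :, :] * mask).sum(axis=-1, keepdims=True)

    # Shape [E, T, 1]: whether the expert processes the token at all.
    selected = mask.sum(axis=-1, keepdims=True).astype(bool)
    return weights, selected


# Three tokens, top-2 routing over four experts (made-up values).
topk_indices = np.array([[0, 2], [1, 2], [3, 0]], dtype=np.int32)
topk_values = np.array([[0.7, 0.3], [0.6, 0.4], [0.9, 0.1]], dtype=np.float32)

weights, selected = expert_mask_and_weights(
    topk_indices, topk_values, num_local_experts=4)

# Same routing seen from rank 1 of a two-way expert-parallel split,
# which owns global experts 2 and 3 under the contiguous-block assumption.
ep_weights, ep_selected = expert_mask_and_weights(
    topk_indices, topk_values, num_local_experts=2,
    experts_per_node=2, ep_rank=1)
```

Comparing every expert id against every top-k slot in one broadcasted operation is what the source comment means by "do mask once": the per-expert loop only collects ids, and the comparison itself runs a single time over the stacked tensor.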

Callers

nothing calls this directly

Calls 15

concat · Function · 0.85
constant · Function · 0.85
expand · Function · 0.85
cast · Function · 0.85
sum · Function · 0.85
repeat_interleave · Function · 0.85
nonzero · Function · 0.85
gather_nd · Function · 0.85
scatter_nd · Function · 0.85
get_runtime_params · Method · 0.80
has_moe_ep · Method · 0.80

Tested by

no test coverage detected