(self, hidden_states, token_selected_experts,
token_final_scales, lora_layer_params, side_stream_id)
| 1288 | ) |
| 1289 | |
| 1290 | def forward_experts(self, hidden_states, token_selected_experts, |
| 1291 | token_final_scales, lora_layer_params, side_stream_id): |
| 1292 | assert side_stream_id == SideStreamIDType.disable, "MoeOOTB does not support using side stream" |
| 1293 | # TODO: remove this check once https://nvbugspro.nvidia.com/bug/4781396 is fixed. |
| 1294 | if lora_layer_params is not None: |
| 1295 | for module in ["mlp_h_to_4h", "mlp_4h_to_h", "mlp_gate"]: |
| 1296 | if lora_layer_params.get_runtime_params(0, module) is not None: |
| 1297 | raise RuntimeError( |
| 1298 | f"MoE OOTB does not support {module} LoRA module, please enable MoE plugin" |
| 1299 | ) |
| 1300 | |
| 1301 | topk_indices = token_selected_experts |
| 1302 | topk_values = token_final_scales |
| 1303 | |
| 1304 | hidden_size = shape(hidden_states, -1) |
| 1305 | # [B*sq, hidden] |
| 1306 | inputs_merged = hidden_states.view(concat([-1, hidden_size])) |
| 1307 | flat_topk_indices = topk_indices.view( |
| 1308 | concat([-1, shape(topk_indices, -1)])) |
| 1309 | flat_topk_values = topk_values.view(concat([-1, |
| 1310 | shape(topk_values, -1)])) |
| 1311 | |
| 1312 | # Create output space |
| 1313 | zero_buffer = inputs_merged * 0.0 |
| 1314 | output = zero_buffer |
| 1315 | |
| 1316 | expert_indices_stack = [] |
| 1317 | indices_stack = [] |
| 1318 | # An expert runs inference on a token when one of the token's top-k indices equals that expert's index. |
| 1319 | # Stack the top-k indices once per expert alongside the expert indices, then build the mask with a single comparison. |
| 1320 | for i, expert in enumerate(self.experts): |
| 1321 | if self.mapping.has_moe_ep(): |
| 1322 | index = i + self.experts_per_node * self.mapping.moe_ep_rank |
| 1323 | else: |
| 1324 | index = i |
| 1325 | expert_indices_stack.append( |
| 1326 | flat_topk_indices.view(concat([1, shape(flat_topk_indices)]))) |
| 1327 | |
| 1328 | indices_stack.append(constant(int32_array(index))) |
| 1329 | |
| 1330 | all_expert_indices = concat(expert_indices_stack, dim=0) |
| 1331 | indices = expand( |
| 1332 | concat(indices_stack).view(concat([len(self.experts), 1, 1])), |
| 1333 | shape(all_expert_indices)) |
| 1334 | |
| 1335 | # Create all experts mask |
| 1336 | all_expert_mask = all_expert_indices == indices |
| 1337 | |
| 1338 | experts_weights = cast( |
| 1339 | sum(flat_topk_values * |
| 1340 | cast(all_expert_mask, flat_topk_values.dtype), |
| 1341 | dim=-1, |
| 1342 | keepdim=True), self.dtype) |
| 1343 | |
| 1344 | all_expert_mask = cast( |
| 1345 | sum(cast(all_expert_mask, flat_topk_values.dtype), |
| 1346 | dim=-1, |
| 1347 | keepdim=True), 'bool') |
nothing calls this directly
no test coverage detected