(self, logits)
| 898 | return topk_indices, topk_values |
| 899 | |
def renormalize(self, logits):
    """Pick the top-k experts per token and renormalize their scores.

    The router logits are cast to float32, ``topk`` is applied along the
    last dimension with ``k=self.moe_config.top_k``, and a softmax over
    the kept values yields the final per-expert scales.

    Args:
        logits: Router logits; presumably one score per expert along the
            last dimension (confirm against the caller).

    Returns:
        A ``(token_selected_experts, token_final_scales)`` pair: the second
        output of ``topk`` (the selected expert indices) and the softmax of
        its first output (the selected scores).
    """
    # Routing math is done in float32 regardless of the incoming dtype.
    fp32_logits = cast(logits, trt.float32)
    top_scores, top_experts = topk(fp32_logits,
                                   k=self.moe_config.top_k,
                                   dim=-1)
    # Softmax only over the kept scores, so each token's scales sum to one.
    return top_experts, softmax(top_scores, dim=-1)
| 907 | |
| 908 | def group_limited_greedy(self, logits): |
| 909 | n_group = self.moe_config.device_limited_n_group |