Sampling draft tokens. Args: logits: torch.Tensor [num_tokens, vocab_size] Logits produced by the draft model. Returns: draft_tokens: torch.Tensor [batch_size * max_draft_len] Draft token ids. Flattened.
(
self,
logits: torch.Tensor,
mapping_lm_head_tp: Mapping = None,
)
| 1093 | return draft_tokens |
| 1094 | |
def draft_sampler(
    self,
    logits: torch.Tensor,
    mapping_lm_head_tp: Mapping = None,
):
    '''
    Pick draft token ids from the draft model's logits.

    Three paths are possible, chosen from ``self.model_config.mapping``:

    1. TP active (``tp_size > 1``) and attention DP disabled: the local
       result is gathered over ``self.model_config.mapping``.
    2. TP active and ``enable_lm_head_tp_in_adp`` set: the local result is
       gathered over ``mapping_lm_head_tp`` and only this rank's slice of
       the gathered rows is kept.
    3. Otherwise: a plain argmax over the last dimension of ``logits``.

    Args:
        logits: torch.Tensor
            [num_tokens, vocab_size]
            Logits produced by the draft model.
        mapping_lm_head_tp: Mapping
            Mapping used for the gather on path 2. It is only read on
            that path.

    Returns:
        draft_tokens: torch.Tensor
            [batch_size * max_draft_len]
            Draft token ids. Flattened.
    '''
    config = self.model_config

    # Resolve the mapping once. It stays None when there is no model
    # config, no 'mapping' attribute, or tp_size <= 1, so the flags below
    # are never read in those cases.
    tp_mapping = None
    if (config is not None and hasattr(config, 'mapping')
            and config.mapping.tp_size > 1):
        tp_mapping = config.mapping

    if tp_mapping is not None and not tp_mapping.enable_attention_dp:
        local_pairs = self.get_local_max_and_combined(logits)
        all_ranks = allgather(local_pairs, self.model_config.mapping, dim=-1)
        return self.get_draft_tokens_from_gathered(all_ranks)

    if tp_mapping is not None and tp_mapping.enable_lm_head_tp_in_adp:
        # For ADP + LM head TP mode, we need to find the global argmax
        # across all TP ranks.
        local_pairs = self.get_local_max_and_combined(logits,
                                                      mapping_lm_head_tp)
        all_ranks = allgather(local_pairs, mapping_lm_head_tp, dim=-1)
        num_ranks = mapping_lm_head_tp.tp_size
        rows_per_rank = logits.shape[0] // num_ranks
        # Regroup the gathered rows per rank and keep only this rank's rows.
        per_rank = all_ranks.view(num_ranks, rows_per_rank, -1)
        own_rows = per_rank[mapping_lm_head_tp.tp_rank]
        return self.get_draft_tokens_from_gathered(own_rows)

    # Simple argmax if no TP or no model config.
    return torch.argmax(logits, dim=-1).type(torch.int32)
| 1139 | |
| 1140 | |
| 1141 | class MTPEagleWorker(MTPWorker): |
no test coverage detected