MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / draft_sampler

Method draft_sampler

tensorrt_llm/_torch/speculative/mtp.py:1095–1138  ·  view source on GitHub ↗

Sampling draft tokens. Args: logits: torch.Tensor [num_tokens, vocab_size] Logits produced by the draft model. Returns: draft_tokens: torch.Tensor [batch_size * max_draft_len] Draft token ids. Flattened.

(
        self,
        logits: torch.Tensor,
        mapping_lm_head_tp: Mapping = None,
    )

Source from the content-addressed store, hash-verified

1093 return draft_tokens
1094
def draft_sampler(
    self,
    logits: torch.Tensor,
    mapping_lm_head_tp: "Optional[Mapping]" = None,
) -> torch.Tensor:
    '''
    Sampling draft tokens (greedy argmax over the vocabulary).

    The path is chosen from ``self.model_config.mapping``:

    * ``tp_size > 1`` without attention DP: each rank builds a local
      result via ``get_local_max_and_combined``, the results are
      all-gathered over the model mapping (last dim), and
      ``get_draft_tokens_from_gathered`` picks the draft tokens.
    * ``tp_size > 1`` with attention DP and ``enable_lm_head_tp_in_adp``:
      the same scheme over ``mapping_lm_head_tp``; the gathered tensor is
      viewed as ``[lm_head_tp_size, num_tokens // lm_head_tp_size, -1]``
      and only this rank's slice (``mapping_lm_head_tp.tp_rank``) is used.
    * Anything else (no model config, missing/None mapping,
      ``tp_size == 1``, attention DP without LM-head TP): plain ``argmax``
      over the last dimension.

    Args:
        logits: torch.Tensor
            [num_tokens, vocab_size]
            Logits produced by the draft model. On the TP paths this is
            presumably each rank's vocab shard (TODO confirm with callers).
        mapping_lm_head_tp: Mapping, optional
            LM-head TP mapping. Required on the attention-DP + LM-head-TP
            path; not used on the other paths.

    Returns:
        draft_tokens: torch.Tensor
            Draft token ids (int32 on the argmax path), one per row of
            ``logits`` -- on the attention-DP + LM-head-TP path, one per
            row of this rank's slice. Flattened; callers presumably see
            [batch_size * max_draft_len] (confirm at call sites).

    Raises:
        ValueError: on the attention-DP + LM-head-TP path, if
            ``mapping_lm_head_tp`` is None or ``num_tokens`` is not a
            multiple of ``mapping_lm_head_tp.tp_size``.
    '''
    # getattr(None, ...) also yields the default, so this single lookup
    # covers "no model config", "config without mapping" and "mapping=None".
    mapping = getattr(self.model_config, 'mapping', None)
    tp_enabled = mapping is not None and mapping.tp_size > 1

    if tp_enabled and not mapping.enable_attention_dp:
        combined = self.get_local_max_and_combined(logits)
        gathered = allgather(combined, mapping, dim=-1)
        return self.get_draft_tokens_from_gathered(gathered)

    if tp_enabled and mapping.enable_lm_head_tp_in_adp:
        # For ADP + LM head TP mode, we need to find the global argmax
        # across all LM-head TP ranks.
        if mapping_lm_head_tp is None:
            raise ValueError(
                "draft_sampler: mapping_lm_head_tp is required when "
                "attention DP and LM-head TP are both enabled")
        lm_tp_size = mapping_lm_head_tp.tp_size
        num_tokens = logits.shape[0]
        # With -1 in the view below, some non-divisible token counts would
        # still reshape successfully and silently mis-slice rows, so reject
        # them before doing any collective work.
        if num_tokens % lm_tp_size != 0:
            raise ValueError(
                f"draft_sampler: num_tokens ({num_tokens}) must be a multiple "
                f"of the LM-head TP size ({lm_tp_size})")
        combined = self.get_local_max_and_combined(logits, mapping_lm_head_tp)
        gathered = allgather(combined, mapping_lm_head_tp, dim=-1)
        gathered = gathered.view(lm_tp_size, num_tokens // lm_tp_size, -1)
        # Keep only the rows that belong to this LM-head TP rank.
        return self.get_draft_tokens_from_gathered(
            gathered[mapping_lm_head_tp.tp_rank])

    # Simple argmax if no TP or no model config; argmax yields int64, so
    # cast to int32 to match the TP paths' token dtype.
    return torch.argmax(logits, dim=-1).type(torch.int32)
1139
1140
1141class MTPEagleWorker(MTPWorker):

Callers 2

forwardMethod · 0.95
forwardMethod · 0.80

Calls 4

allgatherFunction · 0.50
viewMethod · 0.45

Tested by

no test coverage detected