MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / MTPDecodingConfig

Class MTPDecodingConfig

tensorrt_llm/llmapi/llm_args.py:1110–1167  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

1108
1109
class MTPDecodingConfig(DecodingBaseConfig):
    # Configuration for the "MTP" speculative-decoding mode.
    #
    # `max_draft_len` and `max_total_draft_tokens` are written by this class
    # but not declared here; presumably they are fields inherited from
    # DecodingBaseConfig -- confirm against the base class.
    num_nextn_predict_layers: int = 1
    use_relaxed_acceptance_for_thinking: bool = False
    relaxed_topk: int = 1
    relaxed_delta: float = 0.
    use_mtp_vanilla: bool = False
    mtp_eagle_one_model: bool = True

    # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
    # For now a separate value is needed for when MTPDecodingConfig is updated
    # by PyTorchModelEngine.
    num_nextn_predict_layers_from_model_config: int = 1

    # Token ids delimiting the thinking phase:
    #   <think> [thinking phase] </think> [real output]
    # Seeing <think> starts the thinking phase; seeing </think> ends it.
    begin_thinking_phase_token: int = 128798
    end_thinking_phase_token: int = 128799

    def __init__(self, **kwargs):
        """Build the config and mirror an explicit layer count into the draft limits.

        The draft-length attributes are only touched when the caller passed
        `num_nextn_predict_layers`; the value copied is the raw keyword
        argument. A deprecation warning is logged when `mtp_eagle_one_model`
        is falsy.
        """
        super().__init__(**kwargs)

        layers_key = 'num_nextn_predict_layers'
        if layers_key in kwargs:
            # Current MTP only supports a linear tree, so both limits take
            # the same value.
            for draft_attr in ('max_draft_len', 'max_total_draft_tokens'):
                setattr(self, draft_attr, kwargs[layers_key])

        if self.mtp_eagle_one_model:
            return

        logger.warning(
            "2-model style MTP is deprecated. The mtp_eagle_one_model flag will do nothing "
            "in release 1.3. After that, the flag will be removed entirely.")

    @classmethod
    def from_dict(cls, data: dict):
        """Construct a config from `data` and always sync the draft limits.

        Unlike `__init__`, the limits are set from the instance's
        `num_nextn_predict_layers` even when the key is absent from `data`.
        """
        config = cls(**data)
        # Current MTP only supports a linear tree, so both limits take the
        # same value.
        for draft_attr in ('max_draft_len', 'max_total_draft_tokens'):
            setattr(config, draft_attr, config.num_nextn_predict_layers)
        return config

    decoding_type: ClassVar[str] = "MTP"

    def supports_backend(self, backend: str) -> bool:
        """Return True only for the "pytorch" backend."""
        return "pytorch" == backend

    @functools.cached_property
    def num_capture_layers(self) -> int:
        """Return 1 when neither vanilla MTP nor the one-model flag is set, else 0.

        The value is computed on first access and cached on the instance.
        """
        if self.use_mtp_vanilla or self.mtp_eagle_one_model:
            return 0
        return 1

    @functools.cached_property
    def spec_dec_mode(self):
        """Select the torch speculative-decoding mode for this config.

        MTP_EAGLE_ONE_MODEL or MTP_EAGLE is chosen (depending on
        `mtp_eagle_one_model`) only when the model-config layer count is 1
        and vanilla MTP is off; every other combination yields MTP. The value
        is computed on first access and cached on the instance.
        """
        # Imported lazily inside the property, as in the original code.
        from tensorrt_llm._torch.speculative.interface import (
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode,
        )

        single_layer_non_vanilla = (
            self.num_nextn_predict_layers_from_model_config == 1
            and not self.use_mtp_vanilla)
        if not single_layer_non_vanilla:
            return TorchSpeculativeDecodingMode.MTP

        if self.mtp_eagle_one_model:
            return TorchSpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL
        return TorchSpeculativeDecodingMode.MTP_EAGLE

Calls

no outgoing calls