| 1108 | |
| 1109 | |
class MTPDecodingConfig(DecodingBaseConfig):
    """Speculative-decoding configuration for MTP.

    Only the PyTorch backend is accepted (see ``supports_backend``).
    ``max_draft_len`` and ``max_total_draft_tokens`` are derived from
    ``num_nextn_predict_layers``, because MTP currently only drafts along a
    single linear chain (no draft tree).
    """

    # Number of next-n predict layers; also used as the draft length.
    num_nextn_predict_layers: int = 1
    # Relaxed-acceptance knobs; they are consumed outside this class, so their
    # exact semantics are not visible here.
    use_relaxed_acceptance_for_thinking: bool = False
    relaxed_topk: int = 1
    relaxed_delta: float = 0.0
    # Mode selectors; see ``spec_dec_mode`` and ``num_capture_layers``.
    use_mtp_vanilla: bool = False
    mtp_eagle_one_model: bool = True

    # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
    # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
    num_nextn_predict_layers_from_model_config: int = 1

    # When encounter <think>, start thinking phase.
    # When encounter </think>, end thinking phase.
    # <think> [thinking phase] </think> [real output]
    # NOTE(review): the default ids presumably match one specific tokenizer's
    # <think>/</think> tokens — verify when using other models.
    begin_thinking_phase_token: int = 128798
    end_thinking_phase_token: int = 128799

    def __init__(self, **kwargs):
        """Construct the config and sync draft lengths when explicitly given.

        If ``num_nextn_predict_layers`` is passed, ``max_draft_len`` and
        ``max_total_draft_tokens`` are set to it. Otherwise they keep whatever
        the base class assigned.
        """
        super().__init__(**kwargs)
        if 'num_nextn_predict_layers' in kwargs:
            # Use the field as stored after base-class construction (i.e. after
            # any validation/coercion it applies) rather than the raw kwarg.
            # This matches what ``from_dict`` does.
            self.max_draft_len = self.num_nextn_predict_layers
            # Current MTP only supports a linear draft chain, so the total
            # number of draft tokens equals the draft length.
            self.max_total_draft_tokens = self.num_nextn_predict_layers

        if not self.mtp_eagle_one_model:
            logger.warning(
                "2-model style MTP is deprecated. The mtp_eagle_one_model flag will do nothing "
                "in release 1.3. After that, the flag will be removed entirely."
            )

    @classmethod
    def from_dict(cls, data: dict):
        """Build a config from a plain dict.

        Unlike direct construction, the draft lengths are always overwritten
        with ``num_nextn_predict_layers``. The field's default applies when
        ``data`` omits the key.
        """
        out = cls(**data)
        out.max_draft_len = out.num_nextn_predict_layers
        # Current MTP only supports a linear draft chain.
        out.max_total_draft_tokens = out.num_nextn_predict_layers
        return out

    decoding_type: ClassVar[str] = "MTP"

    def supports_backend(self, backend: str) -> bool:
        """Return True only for the ``"pytorch"`` backend."""
        return backend == "pytorch"

    @functools.cached_property
    def num_capture_layers(self) -> int:
        """Return 1 for the 2-model path (neither vanilla nor one-model), else 0.

        NOTE(review): presumably the number of target-model layers whose hidden
        states must be captured for the draft model — confirm with consumers.
        """
        if not self.use_mtp_vanilla and not self.mtp_eagle_one_model:
            return 1
        return 0

    @functools.cached_property
    def spec_dec_mode(self):
        """Map this config onto a PyTorch ``SpeculativeDecodingMode``.

        * ``MTP_EAGLE_ONE_MODEL``: the model config reports exactly one next-n
          layer, vanilla MTP is off, and ``mtp_eagle_one_model`` is set.
        * ``MTP_EAGLE``: same as above, but ``mtp_eagle_one_model`` is unset.
        * ``MTP``: every other case.
        """
        # NOTE(review): this value is cached on first access, but the TODO on
        # ``num_nextn_predict_layers_from_model_config`` says that field is
        # updated later by PyTorchModelEngine. If this property is read before
        # that update, the cached mode may be stale — confirm access order.
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        eagle_style = (self.num_nextn_predict_layers_from_model_config == 1
                       and not self.use_mtp_vanilla)
        if not eagle_style:
            return TorchSpeculativeDecodingMode.MTP
        if self.mtp_eagle_one_model:
            return TorchSpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL
        return TorchSpeculativeDecodingMode.MTP_EAGLE
no outgoing calls