Initialize the drafter (if it has not been initialized yet) and enable Speculative Decoding for subsequent generations. Args: drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. If provided, the previous drafter and drafter model, if they exist, will be overwritten.
(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
)
| 299 | ), f"Model {self.model.__class__.__name__} is not supported." |
| 300 | |
| 301 | def enable_spec_dec( |
| 302 | self, |
| 303 | drafter_model: nn.Module = None, |
| 304 | n_spec_tokens: int = None, |
| 305 | use_glide_drafter: bool = False, |
| 306 | ) -> None: |
| 307 | """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. |
| 308 | |
| 309 | Args: |
| 310 | drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. |
| 311 | If provided, the previous drafter and drafter model, if exist, will be overwritten. |
| 312 | n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. |
| 313 | If not provided, `max_n_spec_tokens` in InferenceConfig will be used. |
| 314 | use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. |
| 315 | If True, the drafter model will be replaced by a glide model. |
| 316 | |
| 317 | ```python |
| 318 | ... |
| 319 | engine = InferenceEngine(model, tokenizer, inference_config) |
| 320 | |
| 321 | engine.enable_spec_dec(drafter_model, n_spec_tokens=5) |
| 322 | engine.generate(...) # Speculative Decoding |
| 323 | |
| 324 | engine.disable_spec_dec() |
| 325 | engine.generate(...) # Normal generation |
| 326 | |
| 327 | engine.enable_spec_dec() |
| 328 | engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens |
| 329 | engine.clear_spec_dec() |
| 330 | ``` |
| 331 | """ |
| 332 | |
| 333 | if drafter_model is None and self.drafter is None: |
| 334 | raise ValueError("Drafter not initialized. Please provide a Drafter Model") |
| 335 | if n_spec_tokens is not None: |
| 336 | assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens |
| 337 | self.n_spec_tokens = n_spec_tokens |
| 338 | if drafter_model is not None: |
| 339 | assert isinstance(drafter_model, nn.Module) |
| 340 | # overwrite the drafter, if exists |
| 341 | self.clear_spec_dec() |
| 342 | self.drafter_model = drafter_model |
| 343 | self.drafter = Drafter( |
| 344 | self.drafter_model, |
| 345 | self.tokenizer, |
| 346 | device=self.device, |
| 347 | dtype=self.dtype, |
| 348 | ) |
| 349 | |
| 350 | # check if the provided drafter model is compatible with GLIDE structure |
| 351 | # when `use_glide_drafter` is set to True |
| 352 | if ( |
| 353 | use_glide_drafter |
| 354 | and hasattr(drafter_model, "model") |
| 355 | and hasattr(drafter_model.model, "layers") |
| 356 | and hasattr(drafter_model.model.layers[0], "cross_attn") |
| 357 | ): |
| 358 | self.use_glide = use_glide_drafter |