hub / github.com/hpcaitech/ColossalAI / enable_spec_dec

Method enable_spec_dec

colossalai/inference/core/llm_engine.py:301–367

Initialize the drafter (if it has not been initialized yet) and enable speculative decoding for subsequent generations. Args: drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. If provided, the previous drafter and drafter model, if they exist, will be overwritten.

(
        self,
        drafter_model: nn.Module = None,
        n_spec_tokens: int = None,
        use_glide_drafter: bool = False,
    )

Source from the content-addressed store, hash-verified

        ), f"Model {self.model.__class__.__name__} is not supported."

    def enable_spec_dec(
        self,
        drafter_model: nn.Module = None,
        n_spec_tokens: int = None,
        use_glide_drafter: bool = False,
    ) -> None:
        """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.

        Args:
            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
                If provided, the previous drafter and drafter model, if exist, will be overwritten.
            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
            use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
                If True, the drafter model will be replaced by a glide model.

        ```python
        ...
        engine = InferenceEngine(model, tokenizer, inference_config)

        engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
        engine.generate(...)  # Speculative Decoding

        engine.disable_spec_dec()
        engine.generate(...)  # Normal generation

        engine.enable_spec_dec()
        engine.generate(...)  # Speculative-Decoding using previously set drafter model and number of spec tokens
        engine.clear_spec_dec()
        ```
        """

        if drafter_model is None and self.drafter is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
            self.n_spec_tokens = n_spec_tokens
        if drafter_model is not None:
            assert isinstance(drafter_model, nn.Module)
            # overwrite the drafter, if exists
            self.clear_spec_dec()
            self.drafter_model = drafter_model
            self.drafter = Drafter(
                self.drafter_model,
                self.tokenizer,
                device=self.device,
                dtype=self.dtype,
            )

            # check if the provided drafter model is compatible with GLIDE structure
            # when `use_glide_drafter` is set to True
            if (
                use_glide_drafter
                and hasattr(drafter_model, "model")
                and hasattr(drafter_model.model, "layers")
                and hasattr(drafter_model.model.layers[0], "cross_attn")
            ):
                self.use_glide = use_glide_drafter
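
The checks visible in the listing above can be tried in isolation with a small stand-in. The sketch below is not ColossalAI code: `ToyEngine` and `looks_like_glide` are hypothetical names, plain `SimpleNamespace` objects stand in for `nn.Module` drafters, and no `Drafter` is built. It assumes that overwriting the drafter resets the GLIDE flag (the body of `clear_spec_dec` is not shown here), it raises `ValueError` where the listing uses `assert` (assertions are skipped under `python -O`), and it adds an empty-`layers` guard that the listing does not have.

```python
from types import SimpleNamespace


def looks_like_glide(model) -> bool:
    """Duck-typed structure check, mirroring the hasattr chain in the listing."""
    return (
        hasattr(model, "model")
        and hasattr(model.model, "layers")
        and len(model.model.layers) > 0  # extra guard, not in the listing
        and hasattr(model.model.layers[0], "cross_attn")
    )


class ToyEngine:
    """Hypothetical stand-in that mirrors only the argument checks shown above."""

    def __init__(self, max_n_spec_tokens: int = 5):
        # Stands in for inference_config.max_n_spec_tokens.
        self.max_n_spec_tokens = max_n_spec_tokens
        self.n_spec_tokens = max_n_spec_tokens
        self.drafter_model = None
        self.use_glide = False

    def enable_spec_dec(self, drafter_model=None, n_spec_tokens=None, use_glide_drafter=False):
        if drafter_model is None and self.drafter_model is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
            # Same bounds as the listing: greater than 1, at most the configured maximum.
            if not 1 < n_spec_tokens <= self.max_n_spec_tokens:
                raise ValueError(f"n_spec_tokens must be in (1, {self.max_n_spec_tokens}]")
            self.n_spec_tokens = n_spec_tokens
        if drafter_model is not None:
            # Assumption: replacing the drafter also resets the GLIDE flag.
            self.use_glide = False
            self.drafter_model = drafter_model
            if use_glide_drafter and looks_like_glide(drafter_model):
                self.use_glide = True


# A drafter whose first layer exposes `cross_attn`, and one that does not.
glide_like = SimpleNamespace(model=SimpleNamespace(layers=[SimpleNamespace(cross_attn=object())]))
plain = SimpleNamespace(model=SimpleNamespace(layers=[SimpleNamespace()]))

engine = ToyEngine(max_n_spec_tokens=5)
engine.enable_spec_dec(glide_like, n_spec_tokens=3, use_glide_drafter=True)
print(engine.n_spec_tokens, engine.use_glide)  # 3 True

engine.enable_spec_dec(plain, use_glide_drafter=True)
print(engine.use_glide)  # False
```

Note that, per the listing, `n_spec_tokens=1` is rejected because the lower bound is strict, and passing `use_glide_drafter=True` with an incompatible drafter does not set the GLIDE flag.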

Callers 2

check_spec_dec · Function · 0.80
infer · Function · 0.80

Calls 4

clear_spec_dec · Method · 0.95
Drafter · Class · 0.90
set_spec_dec_mode · Method · 0.80
warning · Method · 0.45

Tested by 1

check_spec_dec · Function · 0.64