(
cls,
method: str,
num_speculative_tokens: int,
model: str,
target_cache_cfg: CacheConfig,
target_model: str = None,
dtype: str = 'auto',
trust_remote_code: bool = False,
model_format: str = None,
hf_overrides: dict[str, Any] = None,
dist_config: DistConfig = None,
)
| 576 | |
@classmethod
def from_config(
    cls,
    method: str,
    num_speculative_tokens: int,
    model: str,
    target_cache_cfg: CacheConfig,
    target_model: str = None,
    dtype: str = 'auto',
    trust_remote_code: bool = False,
    model_format: str = None,
    hf_overrides: dict[str, Any] = None,
    dist_config: DistConfig = None,
):
    """Build an instance of ``cls`` describing the draft model for speculative decoding.

    Args:
        method: Speculative method name. It is forwarded to
            ``ModelConfig.from_pretrained`` as ``spec_method`` and stored on the
            instance. Methods listed in ``no_cache_methods`` below get no
            draft ``CacheConfig``.
        num_speculative_tokens: Forwarded unchanged to the ``cls`` constructor.
        model: Draft model passed to ``ModelConfig.from_pretrained`` (with
            ``is_draft_model=True``). When falsy, ``target_model`` is used.
        target_cache_cfg: Cache config of the target model. The draft
            ``CacheConfig`` copies a subset of its fields, and its
            ``block_size`` is also forwarded to ``ModelConfig.from_pretrained``.
        target_model: Fallback used when ``model`` is falsy.
        dtype: Forwarded to ``ModelConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``ModelConfig.from_pretrained``.
        model_format: Forwarded to ``ModelConfig.from_pretrained``.
        hf_overrides: Forwarded to ``ModelConfig.from_pretrained``.
        dist_config: Distributed config. A default ``DistConfig()`` is used
            when falsy.

    Returns:
        A new ``cls`` instance built from the resolved model, method, cache
        config, model config, dist config and ``num_speculative_tokens``.

    Raises:
        ValueError: If neither ``model`` nor ``target_model`` is provided.
    """
    model = model or target_model
    if not model:
        # Fail early with a clear message instead of letting
        # ModelConfig.from_pretrained receive None/'' and fail obscurely.
        raise ValueError('from_config requires either `model` or `target_model` '
                         'to locate the draft model.')
    dist_config = dist_config or DistConfig()
    model_config = ModelConfig.from_pretrained(
        model,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        dist_config=dist_config,
        is_draft_model=True,
        spec_method=method,
        block_size=target_cache_cfg.block_size,
        model_format=model_format,
        hf_overrides=hf_overrides,
    )

    # Speculative methods for which no draft-side CacheConfig is built
    # (cache_config stays None). Currently only medusa.
    no_cache_methods = ('medusa', )
    cache_config = None
    if method not in no_cache_methods:
        # Mirror the target's cache geometry and limits for the draft model.
        cache_config = CacheConfig(
            max_batches=target_cache_cfg.max_batches,
            block_size=target_cache_cfg.block_size,
            kernel_block_size=target_cache_cfg.kernel_block_size,
            num_cpu_blocks=target_cache_cfg.num_cpu_blocks,
            num_gpu_blocks=target_cache_cfg.num_gpu_blocks,
            cache_max_entry_count=target_cache_cfg.cache_max_entry_count,
            max_prefill_token_num=target_cache_cfg.max_prefill_token_num,
            device_type=target_cache_cfg.device_type,
            quant_policy=target_cache_cfg.quant_policy,
            migration_backend=target_cache_cfg.migration_backend,
        )

    return cls(
        model=model,
        method=method,
        cache_config=cache_config,
        model_config=model_config,
        dist_config=dist_config,
        num_speculative_tokens=num_speculative_tokens,
    )
| 626 | |
| 627 | |
| 628 | @dataclass |
no test coverage detected