Speculative decoding stats.
# modify from vllm
@dataclass
class SpeculativeDecodingStats:
    """Speculative decoding stats.

    Accumulates counters over speculative-decoding drafts so that acceptance
    metrics can be reported through ``repr``.

    Attributes:
        num_spec_tokens: Number of speculative positions tracked; also the
            length of ``num_accepted_tokens_per_pos``. Must be positive.
        num_drafts: Number of drafts recorded so far.
        num_draft_tokens: Total number of draft tokens recorded.
        num_accepted_tokens: Total number of accepted draft tokens recorded.
        num_accepted_tokens_per_pos: Per-position counters. Each recorded
            draft that accepted ``k`` tokens adds 1 to the first ``k`` entries.
            Allocated as zeros of length ``num_spec_tokens`` when not given.
    """

    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: 'np.ndarray | None' = None

    def __post_init__(self):
        assert self.num_spec_tokens > 0
        if self.num_accepted_tokens_per_pos is None:
            self.num_accepted_tokens_per_pos = np.zeros(self.num_spec_tokens)
        else:
            # Honor caller-supplied counts instead of silently discarding them.
            # Copy, so two instances never share (and mutate) one buffer.
            self.num_accepted_tokens_per_pos = np.array(self.num_accepted_tokens_per_pos, dtype=np.float64)

    def update_from_output(self, outputs: 'EngineOutput'):
        """Update from engine output.

        Reads ``outputs.req_metrics.spec_info``. When it is missing or empty
        (falsy), the stats are left unchanged; otherwise one draft is recorded.

        Args:
            outputs: Engine output whose ``req_metrics`` may carry a
                ``spec_info`` mapping with ``num_draft_tokens`` and
                ``num_accepted_tokens`` entries.
        """
        spec_info = getattr(outputs.req_metrics, 'spec_info', None)
        if spec_info:
            self.num_drafts += 1
            self.num_draft_tokens += spec_info['num_draft_tokens']
            self.num_accepted_tokens += spec_info['num_accepted_tokens']
            # Slicing clamps at the array length, so counts beyond
            # num_spec_tokens do not raise.
            self.num_accepted_tokens_per_pos[:spec_info['num_accepted_tokens']] += 1

    def update_per_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Update with per draft stats.

        A call with ``num_draft_tokens <= 0`` is ignored and does not count as
        a draft.

        Args:
            num_draft_tokens: Number of tokens proposed in this draft.
            num_accepted_tokens: Number of those tokens that were accepted.
        """
        if num_draft_tokens > 0:
            self.num_drafts += 1
            self.num_draft_tokens += num_draft_tokens
            self.num_accepted_tokens += num_accepted_tokens
            self.num_accepted_tokens_per_pos[:num_accepted_tokens] += 1

    def __repr__(self):
        draft_acceptance_rate = (self.num_accepted_tokens / self.num_draft_tokens *
                                 100 if self.num_draft_tokens > 0 else float('nan'))

        # conventionally, mean acceptance length includes the bonus token
        mean_acceptance_length = (1 + self.num_accepted_tokens /
                                  self.num_drafts if self.num_drafts > 0 else float('nan'))

        if self.num_drafts > 0:
            acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts
        else:
            # No drafts yet: report one NaN per speculative position. The
            # previous code sized this by num_accepted_tokens (normally 0),
            # which produced an empty list.
            acceptance_rates = [float('nan')] * self.num_spec_tokens
        rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates)

        return ('SpeculativeDecodingStats('
                f'num_spec_tokens={self.num_spec_tokens}, '
                f'num_drafts={self.num_drafts}, '
                f'num_draft_tokens={self.num_draft_tokens}, '
                f'num_accepted_tokens={self.num_accepted_tokens}, '
                f'draft_acceptance_rate={draft_acceptance_rate:.2f}%, '
                f'mean_acceptance_length={mean_acceptance_length:.2f}, '
                f'per_position_acceptance_rate={rates_str})')
no outgoing calls