(
model, draft, tokenizer, prompt,
block_size=None, max_tokens=256, temperature=0.0, sampler=None,
)
| 427 | |
| 428 | |
| 429 | def stream_generate( |
| 430 | model, draft, tokenizer, prompt, |
| 431 | block_size=None, max_tokens=256, temperature=0.0, sampler=None, |
| 432 | ): |
| 433 | _patch_model(model, draft.config.target_layer_ids) |
| 434 | block_size = block_size if block_size is not None else int(draft.config.block_size) |
| 435 | sampler = sampler or make_sampler(temp=temperature) |
| 436 | |
| 437 | if not isinstance(tokenizer, TokenizerWrapper): |
| 438 | tokenizer = TokenizerWrapper(tokenizer) |
| 439 | |
| 440 | if not isinstance(prompt, mx.array): |
| 441 | if isinstance(prompt, str): |
| 442 | add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token) |
| 443 | prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) |
| 444 | prompt = mx.array(prompt) |
| 445 | |
| 446 | detokenizer = tokenizer.detokenizer |
| 447 | mask_id = int(draft.config.mask_token_id) |
| 448 | tokens = prompt.tolist() |
| 449 | |
| 450 | target_cache = make_prompt_cache(model) |
| 451 | draft_cache = make_prompt_cache(draft) |
| 452 | draft.bind(model) |
| 453 | _target_can_trim = can_trim_prompt_cache(target_cache) |
| 454 | if not _target_can_trim and not _HAS_GDN: |
| 455 | raise RuntimeError( |
| 456 | "This MLX model requires gated-delta rollback support, but " |
| 457 | "mlx_lm.models.gated_delta is unavailable." |
| 458 | ) |
| 459 | _capture = _GDNStateCapture() if not _target_can_trim else None |
| 460 | |
| 461 | try: |
| 462 | tic = time.perf_counter() |
| 463 | with mx.stream(generation_stream): |
| 464 | logits = model(prompt[None], target_cache) |
| 465 | hidden = mx.concatenate(model._hidden_states, axis=-1) |
| 466 | mx.eval(logits, hidden) |
| 467 | prompt_tps = prompt.size / (time.perf_counter() - tic) |
| 468 | |
| 469 | tic = time.perf_counter() |
| 470 | token = sampler(logits[:, -1:])[0, 0].item() |
| 471 | tokens.append(token) |
| 472 | n = 1 |
| 473 | |
| 474 | if token in tokenizer.eos_token_ids: |
| 475 | detokenizer.add_token(token) |
| 476 | detokenizer.finalize() |
| 477 | yield _make_response(detokenizer.last_segment, [token], 1, prompt.size, prompt_tps, n, tic, "stop") |
| 478 | return |
| 479 | |
| 480 | detokenizer.add_token(token) |
| 481 | yield _make_response( |
| 482 | detokenizer.last_segment, |
| 483 | [token], |
| 484 | 1, |
| 485 | prompt.size, |
| 486 | prompt_tps, |
no test coverage detected