MCPcopy
hub / github.com/z-lab/dflash / stream_generate

Function stream_generate

dflash/model_mlx.py:429–582  ·  view source on GitHub ↗
(
    model, draft, tokenizer, prompt,
    block_size=None, max_tokens=256, temperature=0.0, sampler=None,
)

Source from the content-addressed store, hash-verified

427
428
429def stream_generate(
430 model, draft, tokenizer, prompt,
431 block_size=None, max_tokens=256, temperature=0.0, sampler=None,
432):
433 _patch_model(model, draft.config.target_layer_ids)
434 block_size = block_size if block_size is not None else int(draft.config.block_size)
435 sampler = sampler or make_sampler(temp=temperature)
436
437 if not isinstance(tokenizer, TokenizerWrapper):
438 tokenizer = TokenizerWrapper(tokenizer)
439
440 if not isinstance(prompt, mx.array):
441 if isinstance(prompt, str):
442 add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
443 prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
444 prompt = mx.array(prompt)
445
446 detokenizer = tokenizer.detokenizer
447 mask_id = int(draft.config.mask_token_id)
448 tokens = prompt.tolist()
449
450 target_cache = make_prompt_cache(model)
451 draft_cache = make_prompt_cache(draft)
452 draft.bind(model)
453 _target_can_trim = can_trim_prompt_cache(target_cache)
454 if not _target_can_trim and not _HAS_GDN:
455 raise RuntimeError(
456 "This MLX model requires gated-delta rollback support, but "
457 "mlx_lm.models.gated_delta is unavailable."
458 )
459 _capture = _GDNStateCapture() if not _target_can_trim else None
460
461 try:
462 tic = time.perf_counter()
463 with mx.stream(generation_stream):
464 logits = model(prompt[None], target_cache)
465 hidden = mx.concatenate(model._hidden_states, axis=-1)
466 mx.eval(logits, hidden)
467 prompt_tps = prompt.size / (time.perf_counter() - tic)
468
469 tic = time.perf_counter()
470 token = sampler(logits[:, -1:])[0, 0].item()
471 tokens.append(token)
472 n = 1
473
474 if token in tokenizer.eos_token_ids:
475 detokenizer.add_token(token)
476 detokenizer.finalize()
477 yield _make_response(detokenizer.last_segment, [token], 1, prompt.size, prompt_tps, n, tic, "stop")
478 return
479
480 detokenizer.add_token(token)
481 yield _make_response(
482 detokenizer.last_segment,
483 [token],
484 1,
485 prompt.size,
486 prompt_tps,

Callers 1

_run_mlx · Function · 0.85

Calls 8

_patch_model · Function · 0.85
_GDNStateCapture · Class · 0.85
_make_response · Function · 0.85
_trim_recent_cache · Function · 0.85
bind · Method · 0.80
clear · Method · 0.80
rollback · Method · 0.80
close · Method · 0.80

Tested by

no test coverage detected