MCPcopy
hub / github.com/z-lab/dflash / _run_mlx

Function _run_mlx

dflash/benchmark.py:329–377  ·  view source on GitHub ↗
(args: argparse.Namespace)

Source from the content-addressed store, hash-verified

327
328
def _run_mlx(args: argparse.Namespace) -> None:
    """Benchmark baseline decoding against draft-assisted decoding with MLX.

    Loads the target model and the draft model, runs both generators once on
    a tiny prompt, then walks the dataset. For every user turn the prompt is
    generated twice: once with ``mlx_lm.stream_generate`` and once with the
    package's own ``stream_generate`` that also receives the draft model.
    The per-turn metrics are collected and handed to
    ``_print_decode_summary``.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``temperature``, ``model``, ``draft_model``, ``block_size``,
            ``dataset``, ``max_samples``, ``enable_thinking`` and
            ``max_new_tokens``.
    """
    from mlx_lm import stream_generate as stream_generate_baseline
    from mlx_lm.sample_utils import make_sampler

    from .model_mlx import load, load_draft, stream_generate

    sampler = make_sampler(temp=args.temperature)

    logger.info(f"Loading target: {args.model}")
    model, tokenizer = load(args.model)
    logger.info(f"Loading draft: {args.draft_model}")
    draft = load_draft(args.draft_model)

    # An explicit args.block_size wins; otherwise use the draft's configured value.
    if args.block_size is None:
        block_size = int(draft.config.block_size)
    else:
        block_size = args.block_size

    dataset = _limit_dataset(load_and_process_dataset(args.dataset), args.max_samples)

    # Run each generator once on a short prompt and discard the output.
    warmup_prompt = tokenizer.encode("Hi")
    for _ in stream_generate_baseline(model, tokenizer, warmup_prompt, 3, sampler=sampler):
        pass
    for _ in stream_generate(
        model, draft, tokenizer, warmup_prompt, block_size, 3, sampler=sampler
    ):
        pass

    responses = []
    for sample_idx in tqdm(range(len(dataset))):
        sample = dataset[sample_idx]
        conversation = []
        for user_content in sample["turns"]:
            conversation.append({"role": "user", "content": user_content})
            prompt = _apply_chat_template(tokenizer, conversation, args.enable_thinking)

            turn_metrics = {}

            # Baseline pass: metrics are stored under key 1.
            baseline_tokens = []
            baseline_tps = 0
            baseline_stream = stream_generate_baseline(
                model, tokenizer, prompt, args.max_new_tokens, sampler=sampler
            )
            for step in baseline_stream:
                baseline_tokens.append(step.token)
                # Overwritten on every step, so the last reported value is kept.
                baseline_tps = step.generation_tps
            turn_metrics[1] = _make_decode_metrics(len(baseline_tokens), baseline_tps, [1])

            # Draft-assisted pass: metrics are stored under key block_size.
            draft_tokens = []
            accepted_counts = []
            draft_tps = 0
            draft_stream = stream_generate(
                model,
                draft,
                tokenizer,
                prompt,
                block_size,
                args.max_new_tokens,
                sampler=sampler,
            )
            for step in draft_stream:
                draft_tokens.extend(step.tokens)
                accepted_counts.append(step.accepted)
                draft_tps = step.generation_tps
            turn_metrics[block_size] = _make_decode_metrics(
                len(draft_tokens), draft_tps, accepted_counts
            )

            # Feed the draft-assisted answer back so the next turn's prompt includes it.
            assistant_text = tokenizer.decode(draft_tokens)
            conversation.append({"role": "assistant", "content": assistant_text})
            responses.append(turn_metrics)

    _print_decode_summary(responses, block_size)
378
379
380def _run_server(args: argparse.Namespace) -> None:

Callers 1

mainFunction · 0.85

Calls 8

loadFunction · 0.85
load_draftFunction · 0.85
load_and_process_datasetFunction · 0.85
_limit_datasetFunction · 0.85
stream_generateFunction · 0.85
_apply_chat_templateFunction · 0.85
_make_decode_metricsFunction · 0.85
_print_decode_summaryFunction · 0.85

Tested by

no test coverage detected