(args: argparse.Namespace)
| 327 | |
| 328 | |
def _run_mlx(args: argparse.Namespace) -> None:
    """Compare baseline decoding with draft-model decoding on the MLX backend.

    Loads the target model and the draft model, then walks every turn of
    every dataset sample. Each turn's prompt is generated twice: once with
    ``mlx_lm.stream_generate`` and once with this package's
    ``stream_generate`` (which additionally receives the draft model and the
    block size). Per-turn metrics are stored in a dict keyed by ``1`` for the
    baseline pass and by the block size for the draft pass, and the full list
    is handed to ``_print_decode_summary``.

    Args:
        args: Parsed command-line options. This function reads
            ``temperature``, ``model``, ``draft_model``, ``block_size``,
            ``dataset``, ``max_samples``, ``enable_thinking`` and
            ``max_new_tokens``.
    """
    # Function-scope imports: these modules are only loaded when this runner
    # is actually called.
    from mlx_lm import stream_generate as stream_generate_baseline
    from mlx_lm.sample_utils import make_sampler

    from .model_mlx import load, load_draft, stream_generate

    sampler = make_sampler(temp=args.temperature)

    logger.info(f"Loading target: {args.model}")
    model, tokenizer = load(args.model)
    logger.info(f"Loading draft: {args.draft_model}")
    draft = load_draft(args.draft_model)

    # An explicit args.block_size wins; otherwise fall back to the value
    # carried by the draft model's config.
    if args.block_size is None:
        block_size = int(draft.config.block_size)
    else:
        block_size = args.block_size

    dataset = _limit_dataset(
        load_and_process_dataset(args.dataset),
        args.max_samples,
    )

    # Exhaust both generators once on a tiny prompt before any measured run.
    warmup_ids = tokenizer.encode("Hi")
    for _ in stream_generate_baseline(
        model, tokenizer, warmup_ids, 3, sampler=sampler
    ):
        pass
    for _ in stream_generate(
        model, draft, tokenizer, warmup_ids, block_size, 3, sampler=sampler
    ):
        pass

    responses = []
    for sample_idx in tqdm(range(len(dataset))):
        conversation = []
        for user_content in dataset[sample_idx]["turns"]:
            conversation.append({"role": "user", "content": user_content})
            prompt = _apply_chat_template(
                tokenizer, conversation, args.enable_thinking
            )

            # Baseline pass: collect one token per yielded step and keep the
            # throughput reported by the last step.
            baseline_tokens = []
            baseline_tps = 0
            for step in stream_generate_baseline(
                model,
                tokenizer,
                prompt,
                args.max_new_tokens,
                sampler=sampler,
            ):
                baseline_tokens.append(step.token)
                baseline_tps = step.generation_tps
            baseline_metrics = _make_decode_metrics(
                len(baseline_tokens), baseline_tps, [1]
            )

            # Draft pass: each step may contribute several tokens and reports
            # its own ``accepted`` value.
            draft_tokens = []
            accepted_per_step = []
            draft_tps = 0
            for step in stream_generate(
                model,
                draft,
                tokenizer,
                prompt,
                block_size,
                args.max_new_tokens,
                sampler=sampler,
            ):
                draft_tokens.extend(step.tokens)
                accepted_per_step.append(step.accepted)
                draft_tps = step.generation_tps
            draft_metrics = _make_decode_metrics(
                len(draft_tokens), draft_tps, accepted_per_step
            )

            # The draft pass's decoded text becomes the assistant reply that
            # later turns of the same sample are conditioned on.
            conversation.append(
                {"role": "assistant", "content": tokenizer.decode(draft_tokens)}
            )
            responses.append({1: baseline_metrics, block_size: draft_metrics})

    _print_decode_summary(responses, block_size)
| 378 | |
| 379 | |
| 380 | def _run_server(args: argparse.Namespace) -> None: |
no test coverage detected