()
| 478 | |
| 479 | |
def main() -> None:
    """Parse command-line arguments for the DFlash benchmark and run it.

    After parsing, the argument combination is validated and the run is
    dispatched on ``--backend``: ``transformers`` goes to
    ``_run_transformers`` and ``mlx`` goes to ``_run_mlx`` (both require
    ``--draft-model``); every other backend choice (``sglang``, ``vllm``)
    goes to ``_run_server``.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when the arguments
            are invalid, i.e. ``--enable-thinking`` combined with a
            Qwen3-4B/Qwen3-8B model, or the ``transformers``/``mlx`` backend
            without ``--draft-model``.
    """
    parser = argparse.ArgumentParser(description="DFlash benchmark")
    parser.add_argument("--backend", choices=["transformers", "sglang", "vllm", "mlx"], required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.0)

    parser.add_argument("--draft-model", type=str, default=None)
    parser.add_argument("--block-size", type=int, default=None)
    parser.add_argument("--max-samples", type=int, default=None)

    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000")
    parser.add_argument("--num-prompts", type=int, default=1024)
    parser.add_argument("--concurrency", type=int, default=1)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=1)
    parser.add_argument("--enable-thinking", action="store_true")
    parser.add_argument("--timeout-s", type=int, default=3600)

    args = parser.parse_args()

    # Use parser.error rather than `assert`: assertions are stripped under
    # `python -O`, which would silently skip this guard, and parser.error
    # reports the problem as a normal CLI usage error instead of a traceback.
    model_name = args.model.lower()
    if args.enable_thinking and any(tag in model_name for tag in ("qwen3-4b", "qwen3-8b")):
        parser.error(
            "DFlash draft models for Qwen3-4B and Qwen3-8B were not trained with thinking traces. "
            "Using --enable-thinking will lead to suboptimal performance."
        )

    # These two backends cannot run without an explicit draft model.
    if args.backend in ("transformers", "mlx") and args.draft_model is None:
        parser.error(f"--draft-model is required for {args.backend} backend")

    if args.backend == "transformers":
        _run_transformers(args)
    elif args.backend == "mlx":
        _run_mlx(args)
    else:
        _run_server(args)
| 517 | |
| 518 | |
| 519 | if __name__ == "__main__": |
no test coverage detected