Worker loop for ranks > 0. Waits for commands from rank 0.
(group)
| 617 | |
| 618 | |
def run_worker(group):
    """Run the command loop for a worker rank (ranks > 0) driven by rank 0.

    A ``DistributedCoordinator`` is built around ``group`` and then polled
    with ``wait_for_command()`` until a shutdown command is received.
    Each command is handled as follows:

    * ``CMD_LOAD_MODEL``: obtain the model name through
      ``broadcast_model_name()``, load it with ``mlx_lm.load`` and pass the
      model through ``pipeline_auto_parallel(model, group)``.
    * ``CMD_GENERATE``: obtain the token count, the prompt tokens and the
      generation parameters through the coordinator, then iterate
      ``stream_generate`` to exhaustion, discarding everything it yields.
      If no model has been loaded yet, a message is logged and nothing
      else is done for that command.
    * ``CMD_SHUTDOWN``: log and leave the loop.

    A command value that matches none of these is ignored and the loop
    waits for the next command.

    Args:
        group: Object handed to ``DistributedCoordinator`` and
            ``pipeline_auto_parallel``; its ``rank()`` is used in log lines.

    Returns:
        None, once ``CMD_SHUTDOWN`` has been received.
    """
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from coordinator import (
        DistributedCoordinator,
        CMD_LOAD_MODEL,
        CMD_GENERATE,
        CMD_SHUTDOWN,
    )
    from sharding import pipeline_auto_parallel
    import mlx.core as mx  # noqa: F401 -- not referenced below; import kept as-is

    def _consume_generation(active_model, active_tokenizer, prompt_tokens, params):
        # The yielded values are thrown away; this function never uses the
        # generated output itself.
        sampler = make_sampler(
            temp=params["temperature"],
            top_p=params["top_p"],
        )
        responses = stream_generate(
            active_model,
            active_tokenizer,
            prompt=prompt_tokens,
            max_tokens=params["max_tokens"],
            sampler=sampler,
        )
        for _ in responses:
            pass

    coord = DistributedCoordinator(group)
    # Both remain None until a CMD_LOAD_MODEL command has been handled.
    model = tokenizer = None

    print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr)

    while True:
        command, payload_size = coord.wait_for_command()

        if command == CMD_LOAD_MODEL:
            model_name = coord.broadcast_model_name()
            print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr)
            model, tokenizer = load(model_name)
            model = pipeline_auto_parallel(model, group)
            print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr)
            continue

        if command == CMD_GENERATE:
            if model is None:
                # NOTE(review): this path goes back to wait_for_command()
                # without calling the broadcast_* methods used below --
                # confirm rank 0 does not issue them in this situation.
                print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr)
                continue

            # Receive the request pieces in a fixed order: count, tokens,
            # then generation parameters.
            n_tokens = coord.broadcast_token_count(payload_size)
            # A zero-filled list of the announced length is passed in;
            # the returned array is what gets used as the prompt.
            received = coord.broadcast_tokens([0] * n_tokens)
            prompt_tokens = received.tolist()

            params = coord.broadcast_generation_params()

            _consume_generation(model, tokenizer, prompt_tokens, params)
            continue

        if command == CMD_SHUTDOWN:
            print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr)
            break
| 670 | |
| 671 | |
| 672 | async def serve(address): |
no test coverage detected