MCPcopy
hub / github.com/mudler/LocalAI / run_worker

Function run_worker

backend/python/mlx-distributed/backend.py:619–669  ·  view source on GitHub ↗

Worker loop for ranks > 0. Waits for commands from rank 0.

(group)

Source from the content-addressed store, hash-verified

617
618
def run_worker(group):
    """Run the command loop for a rank > 0 until a shutdown command arrives.

    The worker blocks on ``coordinator.wait_for_command()`` and reacts to the
    command it receives:

    * ``CMD_LOAD_MODEL``: receive the model name, load it with
      ``mlx_lm.load`` and shard it with ``pipeline_auto_parallel``.
    * ``CMD_GENERATE``: receive the token count, the prompt tokens and the
      generation parameters, then drive ``stream_generate`` to completion.
      The generated output is discarded on this rank.
    * ``CMD_SHUTDOWN``: leave the loop and return.

    Any other command value is ignored and the loop waits again.

    Args:
        group: Distributed group handle. It must expose ``rank()`` and is
            handed to ``DistributedCoordinator`` and
            ``pipeline_auto_parallel``.
    """
    from mlx_lm import load, stream_generate
    from mlx_lm.sample_utils import make_sampler
    from coordinator import (
        DistributedCoordinator,
        CMD_LOAD_MODEL,
        CMD_GENERATE,
        CMD_SHUTDOWN,
    )
    from sharding import pipeline_auto_parallel

    coordinator = DistributedCoordinator(group)
    # The rank is only used for log prefixes; look it up once.
    rank = group.rank()
    model = None
    tokenizer = None

    print(f"[Rank {rank}] Worker started, waiting for commands...", file=sys.stderr)

    while True:
        cmd, payload_size = coordinator.wait_for_command()

        if cmd == CMD_LOAD_MODEL:
            model_name = coordinator.broadcast_model_name()
            print(f"[Rank {rank}] Loading model: {model_name}", file=sys.stderr)
            model, tokenizer = load(model_name)
            model = pipeline_auto_parallel(model, group)
            print(f"[Rank {rank}] Model loaded and sharded", file=sys.stderr)

        elif cmd == CMD_GENERATE:
            # Take part in every broadcast that follows a generate command
            # BEFORE deciding whether this rank can generate. Skipping them
            # would leave this rank out of step with the others, and the next
            # wait_for_command() would read payload data as a command.
            token_count = coordinator.broadcast_token_count(payload_size)
            # The list of zeros is a placeholder of the right length; the
            # values used for the prompt are the ones the broadcast returns.
            tokens = coordinator.broadcast_tokens([0] * token_count).tolist()
            gen_params = coordinator.broadcast_generation_params()

            if model is None:
                # NOTE(review): the command stream stays aligned now, but if
                # rank 0 expects this rank during generation it may still
                # stall - confirm how rank 0 handles a worker without a model.
                print(f"[Rank {rank}] No model loaded, skipping generate", file=sys.stderr)
                continue

            sampler = make_sampler(
                temp=gen_params["temperature"],
                top_p=gen_params["top_p"],
            )

            # Exhaust the generator purely for its side effects; the yielded
            # responses are not used on this rank.
            for _ in stream_generate(
                model,
                tokenizer,
                prompt=tokens,
                max_tokens=gen_params["max_tokens"],
                sampler=sampler,
            ):
                pass

        elif cmd == CMD_SHUTDOWN:
            print(f"[Rank {rank}] Shutting down", file=sys.stderr)
            break
670
671
672async def serve(address):

Callers 1

backend.pyFile · 0.85

Calls 8

wait_for_commandMethod · 0.95
broadcast_model_nameMethod · 0.95
broadcast_token_countMethod · 0.95
broadcast_tokensMethod · 0.95
pipeline_auto_parallelFunction · 0.90
loadFunction · 0.50

Tested by

no test coverage detected