gRPC servicer for distributed MLX inference (runs on rank 0). When started by LocalAI (server mode), distributed init happens at LoadModel time using config from model options or environment variables.
| 73 | |
| 74 | |
| 75 | class BackendServicer(backend_pb2_grpc.BackendServicer): |
| 76 | """gRPC servicer for distributed MLX inference (runs on rank 0). |
| 77 | |
| 78 | When started by LocalAI (server mode), distributed init happens at |
| 79 | LoadModel time using config from model options or environment variables. |
| 80 | """ |
| 81 | |
def __init__(self):
    """Create a servicer with no model loaded and no distributed state."""
    # Distributed runtime handles; LoadModel assigns these when a hostfile
    # is configured.
    self.group = None
    self.dist_backend = None
    self.coordinator = None

    # Model state. NOTE(review): presumably populated later by LoadModel;
    # the assignment is not visible in this part of the file.
    self.model = None
    self.tokenizer = None

    # Parsed model options; LoadModel replaces this with the result of
    # parse_options(request.Options).
    self.options = {}

    # NOTE(review): these look like prompt/KV-cache bookkeeping; confirm
    # against the code that sets them.
    self.lru_cache = None
    self.model_key = None
    self.max_kv_size = None
| 92 | |
def Health(self, request, context):
    """Answer a health probe with a Reply whose message is the bytes b"OK".

    Neither ``request`` nor ``context`` is inspected.
    """
    payload = "OK".encode('utf-8')
    return backend_pb2.Reply(message=payload)
| 95 | |
| 96 | async def LoadModel(self, request, context): |
| 97 | try: |
| 98 | import mlx.core as mx |
| 99 | from mlx_lm import load |
| 100 | from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache |
| 101 | |
| 102 | print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr) |
| 103 | |
| 104 | self.options = parse_options(request.Options) |
| 105 | print(f"Options: {self.options}", file=sys.stderr) |
| 106 | |
| 107 | # Get distributed config from model options, falling back to env vars. |
| 108 | # If neither is set, run as single-node (no distributed). |
| 109 | hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", "")) |
| 110 | dist_backend = str(self.options.get("distributed_backend", |
| 111 | os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring"))) |
| 112 | # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator). |
| 113 | # Not in model options — rank 0 is the coordinator, workers get |
| 114 | # the address via their own --coordinator CLI flag. |
| 115 | jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "") |
| 116 | |
| 117 | if hostfile: |
| 118 | from coordinator import DistributedCoordinator, CMD_LOAD_MODEL |
| 119 | from sharding import pipeline_auto_parallel |
| 120 | |
| 121 | print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr) |
| 122 | self.dist_backend = dist_backend |
| 123 | self.group = mlx_distributed_init( |
| 124 | rank=0, |
| 125 | hostfile=hostfile, |
| 126 | backend=dist_backend, |
| 127 | coordinator=jaccl_coordinator or None, |
| 128 | ) |
| 129 | self.coordinator = DistributedCoordinator(self.group) |
| 130 | self.coordinator.broadcast_command(CMD_LOAD_MODEL) |
| 131 | self.coordinator.broadcast_model_name(request.Model) |
| 132 | else: |