MCP · Copy · Index your code
hub / github.com/mudler/LocalAI / BackendServicer

Class BackendServicer

backend/python/mlx-distributed/backend.py:75–616  ·  view source on GitHub ↗

gRPC servicer for distributed MLX inference (runs on rank 0). When LocalAI starts the backend (server mode), distributed initialization happens at LoadModel time, using configuration taken from the model options or, as a fallback, from environment variables.

Source retrieved from the content-addressed store (hash-verified)

73
74
75class BackendServicer(backend_pb2_grpc.BackendServicer):
76 """gRPC servicer for distributed MLX inference (runs on rank 0).
77
78 When started by LocalAI (server mode), distributed init happens at
79 LoadModel time using config from model options or environment variables.
80 """
81
82 def __init__(self):
83 self.group = None
84 self.dist_backend = None
85 self.model = None
86 self.tokenizer = None
87 self.coordinator = None
88 self.options = {}
89 self.lru_cache = None
90 self.model_key = None
91 self.max_kv_size = None
92
93 def Health(self, request, context):
94 return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
95
96 async def LoadModel(self, request, context):
97 try:
98 import mlx.core as mx
99 from mlx_lm import load
100 from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
101
102 print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)
103
104 self.options = parse_options(request.Options)
105 print(f"Options: {self.options}", file=sys.stderr)
106
107 # Get distributed config from model options, falling back to env vars.
108 # If neither is set, run as single-node (no distributed).
109 hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", ""))
110 dist_backend = str(self.options.get("distributed_backend",
111 os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring")))
112 # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator).
113 # Not in model options — rank 0 is the coordinator, workers get
114 # the address via their own --coordinator CLI flag.
115 jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "")
116
117 if hostfile:
118 from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
119 from sharding import pipeline_auto_parallel
120
121 print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr)
122 self.dist_backend = dist_backend
123 self.group = mlx_distributed_init(
124 rank=0,
125 hostfile=hostfile,
126 backend=dist_backend,
127 coordinator=jaccl_coordinator or None,
128 )
129 self.coordinator = DistributedCoordinator(self.group)
130 self.coordinator.broadcast_command(CMD_LOAD_MODEL)
131 self.coordinator.broadcast_model_name(request.Model)
132 else:

Callers 1

serveFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected