(self, request, context)
| 94 | return backend_pb2.Reply(message=bytes("OK", 'utf-8')) |
| 95 | |
| 96 | async def LoadModel(self, request, context): |
| 97 | try: |
| 98 | import mlx.core as mx |
| 99 | from mlx_lm import load |
| 100 | from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache |
| 101 | |
| 102 | print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr) |
| 103 | |
| 104 | self.options = parse_options(request.Options) |
| 105 | print(f"Options: {self.options}", file=sys.stderr) |
| 106 | |
| 107 | # Get distributed config from model options, falling back to env vars. |
| 108 | # If no hostfile is set in either place, run single-node; the backend defaults to "ring". |
| 109 | hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", "")) |
| 110 | dist_backend = str(self.options.get("distributed_backend", |
| 111 | os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring"))) |
| 112 | # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator). |
| 113 | # Not in model options — rank 0 is the coordinator, workers get |
| 114 | # the address via their own --coordinator CLI flag. |
| 115 | jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "") |
| 116 | |
| 117 | if hostfile: |
| 118 | from coordinator import DistributedCoordinator, CMD_LOAD_MODEL |
| 119 | from sharding import pipeline_auto_parallel |
| 120 | |
| 121 | print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr) |
| 122 | self.dist_backend = dist_backend |
| 123 | self.group = mlx_distributed_init( |
| 124 | rank=0, |
| 125 | hostfile=hostfile, |
| 126 | backend=dist_backend, |
| 127 | coordinator=jaccl_coordinator or None, |
| 128 | ) |
| 129 | self.coordinator = DistributedCoordinator(self.group) |
| 130 | self.coordinator.broadcast_command(CMD_LOAD_MODEL) |
| 131 | self.coordinator.broadcast_model_name(request.Model) |
| 132 | else: |
| 133 | print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr) |
| 134 | |
| 135 | # Build tokenizer config from request and options |
| 136 | tokenizer_config = {} |
| 137 | if request.TrustRemoteCode or self.options.get("trust_remote_code", False): |
| 138 | tokenizer_config["trust_remote_code"] = True |
| 139 | # Token overrides from options |
| 140 | for key in ["eos_token", "pad_token", "bos_token", "unk_token", |
| 141 | "sep_token", "cls_token", "mask_token"]: |
| 142 | if key in self.options: |
| 143 | tokenizer_config[key] = self.options[key] |
| 144 | |
| 145 | if tokenizer_config: |
| 146 | print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr) |
| 147 | self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) |
| 148 | else: |
| 149 | self.model, self.tokenizer = load(request.Model) |
| 150 | |
| 151 | if self.group is not None: |
| 152 | from sharding import pipeline_auto_parallel |
| 153 | self.model = pipeline_auto_parallel(self.model, self.group) |
nothing calls this directly
no test coverage detected