Initialize MLX distributed runtime. Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank binds to its own entry (hostfile[rank]) and connects to neighbors for the ring pipeline. JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names. MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that helps all ranks establish RDMA connections.
(rank, hostfile, backend="ring", coordinator=None)
| 39 | |
| 40 | |
def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None):
    """Initialize MLX distributed runtime.

    Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank
    binds to its own entry (hostfile[rank]) and connects to neighbors for the
    ring pipeline.

    JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names.
    MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that
    helps all ranks establish RDMA connections.

    Args:
        rank: This process's rank; exported as ``MLX_RANK`` via ``str(rank)``.
        hostfile: Value exported as ``MLX_HOSTFILE`` (ring) or
            ``MLX_IBV_DEVICES`` (jaccl).
        backend: ``"ring"`` (default) or ``"jaccl"``.
        coordinator: Optional value for ``MLX_JACCL_COORDINATOR``. Only used
            by the jaccl backend, and only exported when truthy; an existing
            value in the environment is otherwise left untouched.

    Returns:
        Whatever ``mx.distributed.init(backend=..., strict=True)`` returns.

    Raises:
        ValueError: If ``backend`` is neither ``"ring"`` nor ``"jaccl"``.
    """
    # Reject a bad backend before the MLX import, so a typo is reported as
    # ValueError even on hosts where MLX is not installed.
    if backend not in ("ring", "jaccl"):
        raise ValueError(f"Unknown backend: {backend}")

    # Imported lazily so the module can be loaded without MLX present.
    import mlx.core as mx

    # The environment variables must be set before init() is called.
    if backend == "ring":
        os.environ["MLX_HOSTFILE"] = hostfile
        os.environ["MLX_RANK"] = str(rank)
        os.environ["MLX_RING_VERBOSE"] = "1"
    else:  # backend == "jaccl"
        os.environ["MLX_IBV_DEVICES"] = hostfile
        os.environ["MLX_RANK"] = str(rank)
        if coordinator:
            os.environ["MLX_JACCL_COORDINATOR"] = coordinator

    return mx.distributed.init(backend=backend, strict=True)
| 67 | |
| 68 | |
| 69 | # Re-export the shared helper under the local name for back-compat with |
no test coverage detected