| 17 | |
| 18 | |
class DistributedCoordinator:
    """Rank-0-to-all broadcast helpers built on ``mx.distributed.all_sum``.

    Each ``broadcast_*`` method has rank 0 contribute the real values while
    every other rank contributes zeros of the same shape and dtype; summing
    across the group therefore leaves every rank holding rank 0's values.

    NOTE(review): every method performs an ``all_sum`` on ``group``, so all
    ranks are presumably expected to call the same method in the same order
    with matching shapes -- confirm against the callers.
    """

    def __init__(self, group):
        """Store the group and cache its rank and size.

        Args:
            group: Object exposing ``rank()`` and ``size()``; it is also
                passed as ``group=`` to ``mx.distributed.all_sum``.
        """
        self.group = group
        self.rank = group.rank()
        self.world_size = group.size()

    def _broadcast_from_root(self, root_values, shape, dtype):
        """Return rank 0's ``root_values`` on every rank as an evaluated array.

        Args:
            root_values: Data converted with ``mx.array`` on rank 0 only;
                never touched on other ranks.
            shape: Shape of the zero placeholder built on non-root ranks;
                never touched on rank 0.
            dtype: Element type used for both the real and placeholder arrays.

        Returns:
            The result of ``all_sum``, already passed through ``mx.eval``.
        """
        if self.rank == 0:
            local = mx.array(root_values, dtype=dtype)
        else:
            # Zeros are the additive identity, so the sum equals rank 0's data.
            local = mx.zeros(shape, dtype=dtype)
        summed = mx.distributed.all_sum(local, group=self.group)
        # Force evaluation here so callers can read values out immediately.
        mx.eval(summed)
        return summed

    def broadcast_command(self, cmd, payload_size=0):
        """Rank 0 broadcasts a command to all ranks.

        Uses all_sum with only rank 0 providing non-zero values so every
        rank receives the same command array.

        Args:
            cmd: Command code sent by rank 0 (transported as int32). The
                value passed on other ranks is ignored.
            payload_size: Size value sent alongside the command (int32).
                Ignored on non-root ranks.

        Returns:
            A ``(cmd, payload_size)`` tuple of Python ints, as provided by
            rank 0.
        """
        result = self._broadcast_from_root([cmd, payload_size], (2,), mx.int32)
        return int(result[0].item()), int(result[1].item())

    def broadcast_tokens(self, tokens):
        """Broadcast input token ids from rank 0 to all ranks.

        Rank 0 provides the real token array; other ranks provide zeros of the
        same shape. ``all_sum`` ensures every rank ends up with identical data.

        Args:
            tokens: On rank 0, the token ids to send (converted to int32).
                On other ranks only ``len(tokens)`` is used, so any sized
                object of the same length works as a placeholder.

        Returns:
            The evaluated int32 array produced by ``all_sum``.
        """
        # Only non-root ranks need the length; rank 0 hands ``tokens`` to
        # ``mx.array`` unchanged, exactly as before the helper was factored out.
        shape = None if self.rank == 0 else (len(tokens),)
        return self._broadcast_from_root(tokens, shape, mx.int32)

    def broadcast_token_count(self, count):
        """Broadcast the number of tokens so workers can prepare a buffer.

        Args:
            count: Token count sent by rank 0 (transported as int32).
                Ignored on non-root ranks.

        Returns:
            Rank 0's count as a Python int.
        """
        result = self._broadcast_from_root([count], (1,), mx.int32)
        return int(result[0].item())

    def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
        """Broadcast generation parameters from rank 0.

        All three values travel in a single float32 array, so the returned
        ``temperature`` and ``top_p`` are float32-rounded and ``max_tokens``
        is only exact while it is representable in float32.

        Args:
            max_tokens: Generation length limit sent by rank 0.
            temperature: Sampling temperature sent by rank 0.
            top_p: Nucleus-sampling threshold sent by rank 0.
            Arguments passed on non-root ranks are ignored.

        Returns:
            A dict with keys ``"max_tokens"`` (int), ``"temperature"``
            (float) and ``"top_p"`` (float) holding rank 0's values.
        """
        result = self._broadcast_from_root(
            [max_tokens, temperature, top_p], (3,), mx.float32
        )
        return {
            "max_tokens": int(result[0].item()),
            "temperature": float(result[1].item()),
            "top_p": float(result[2].item()),
        }
| 76 |
no outgoing calls
no test coverage detected