MCPcopy
hub / github.com/mudler/LocalAI / DistributedCoordinator

Class DistributedCoordinator

backend/python/mlx-distributed/coordinator.py:19–104  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

17
18
class DistributedCoordinator:
    """Broadcast control data from rank 0 to every rank of an MLX group.

    Every ``broadcast_*`` method uses the same technique. Rank 0 contributes
    the real values and every other rank contributes zeros of the same shape
    and dtype. ``mx.distributed.all_sum`` therefore hands every rank a copy of
    rank 0's data.

    ``all_sum`` is a collective operation, so every rank in ``group`` must
    call the same ``broadcast_*`` methods, in the same order, with matching
    buffer shapes.
    """

    def __init__(self, group):
        # NOTE(review): ``group`` is presumably the object returned by
        # ``mx.distributed.init()``; only ``rank()`` and ``size()`` are used.
        self.group = group
        self.rank = group.rank()
        self.world_size = group.size()

    def _broadcast_from_rank0(self, values, shape, dtype):
        """Return rank 0's ``values`` on every rank as an evaluated mx.array.

        Args:
            values: Data read only on rank 0 (anything ``mx.array`` accepts).
            shape: Shape of the zero buffer contributed by non-root ranks.
                Must match the shape of ``mx.array(values)`` on rank 0.
            dtype: MLX dtype used for the buffer on all ranks.

        Returns:
            The summed array, which equals rank 0's data on every rank.
        """
        if self.rank == 0:
            local = mx.array(values, dtype=dtype)
        else:
            local = mx.zeros(shape, dtype=dtype)
        result = mx.distributed.all_sum(local, group=self.group)
        # MLX is lazy: evaluate now so the collective executes at this call
        # site rather than whenever the result happens to be consumed.
        mx.eval(result)
        return result

    def broadcast_command(self, cmd, payload_size=0):
        """Rank 0 broadcasts a command to all ranks.

        Args:
            cmd: Integer command code. Only read on rank 0.
            payload_size: Integer accompanying the command. Only read on
                rank 0.

        Returns:
            ``(cmd, payload_size)`` as Python ints, identical on every rank.
        """
        cmd_out, size_out = self._broadcast_from_rank0(
            [cmd, payload_size], (2,), mx.int32
        ).tolist()
        return int(cmd_out), int(size_out)

    def broadcast_tokens(self, tokens):
        """Broadcast input token ids from rank 0 to all ranks.

        On rank 0, ``tokens`` holds the real ids. On other ranks only
        ``len(tokens)`` is used, so any placeholder of the right length works.
        That length can be obtained via ``broadcast_token_count``.

        Returns:
            An evaluated 1-D int32 ``mx.array`` with rank 0's token ids,
            identical on every rank.
        """
        return self._broadcast_from_rank0(tokens, (len(tokens),), mx.int32)

    def broadcast_token_count(self, count):
        """Broadcast the number of tokens so workers can prepare a buffer.

        Args:
            count: Token count. Only read on rank 0.

        Returns:
            Rank 0's ``count`` as a Python int, identical on every rank.
        """
        (count_out,) = self._broadcast_from_rank0(
            [count], (1,), mx.int32
        ).tolist()
        return int(count_out)

    def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
        """Broadcast generation parameters from rank 0.

        All three values travel in a single float32 buffer. As a result:
        - ``max_tokens`` round-trips exactly only up to 2**24.
        - ``temperature`` and ``top_p`` come back rounded to float32 precision.
        The rounded values are identical on every rank, rank 0 included.

        Returns:
            A dict with keys ``max_tokens`` (int), ``temperature`` (float) and
            ``top_p`` (float).
        """
        max_tokens_out, temperature_out, top_p_out = self._broadcast_from_rank0(
            [max_tokens, temperature, top_p], (3,), mx.float32
        ).tolist()
        return {
            "max_tokens": int(max_tokens_out),
            "temperature": float(temperature_out),
            "top_p": float(top_p_out),
        }
76

Callers 2

LoadModel (method) · 0.90
run_worker (function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected