Apply pipeline parallelism to a model. Each rank keeps only its slice of layers; the first layer receives from the previous rank, and the last layer sends to the next rank. Args: `model` — the MLX model (must have `model.layers` or similar); `group` — the distributed group.
(model, group, start_layer=None, end_layer=None)
| 95 | |
| 96 | |
def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None):
    """Apply pipeline parallelism to a model.

    Each rank only keeps its slice of layers. The first kept layer is wrapped
    in ``PipelineFirstLayer`` (receives from the previous rank) and the last
    kept layer in ``PipelineLastLayer`` (sends to the next rank).

    When a bound is omitted it is taken from an even split of the layers
    across ranks, where the first ``total_layers % world_size`` ranks get one
    extra layer. Each bound is filled in independently, so a caller-supplied
    bound is never overridden.

    Args:
        model: The MLX model (must have model.layers or similar).
        group: The distributed group; ``group.rank()`` and ``group.size()``
            select this rank's share of layers.
        start_layer: First layer index for this rank (auto-computed if None).
        end_layer: Last layer index (exclusive) for this rank (auto-computed
            if None).

    Returns:
        The same ``model`` object, with its inner layer list replaced by this
        rank's wrapped slice.

    Raises:
        ValueError: If this rank's slice ``layers[start_layer:end_layer]`` is
            empty (e.g. more ranks than layers, or inconsistent bounds), or
            if the inner model exposes neither ``layers`` nor ``h``, so the
            slice could not be installed.
    """
    rank = group.rank()
    world_size = group.size()

    inner = get_inner_model(model)
    layers = list(get_layers(inner))
    total_layers = len(layers)

    if start_layer is None or end_layer is None:
        # Even contiguous split; the first `remainder` ranks get one extra layer.
        layers_per_rank, remainder = divmod(total_layers, world_size)
        auto_start = rank * layers_per_rank + min(rank, remainder)
        auto_end = auto_start + layers_per_rank + (1 if rank < remainder else 0)
        # Fill in only the bound(s) the caller omitted; never discard an
        # explicitly supplied one.
        if start_layer is None:
            start_layer = auto_start
        if end_layer is None:
            end_layer = auto_end

    layers = layers[start_layer:end_layer]
    if not layers:
        # Without this check, layers[0] below fails with an opaque IndexError.
        raise ValueError(
            f"rank {rank} was assigned no layers "
            f"(start_layer={start_layer}, end_layer={end_layer}, "
            f"total_layers={total_layers}, world_size={world_size})"
        )

    # Force evaluation of the kept layers (MLX arrays are lazy).
    # NOTE(review): presumably so only this rank's slice gets materialized — confirm.
    for layer in layers:
        mx.eval(layer)

    # Wrap first and last layers. With a single-layer slice both indices refer
    # to the same element, so PipelineLastLayer wraps the PipelineFirstLayer.
    layers[0] = PipelineFirstLayer(layers[0], rank, group=group)
    layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group)

    # Replace layers on the inner model. Failing loudly matters here: silently
    # skipping this would leave every rank running the full, un-pipelined model.
    if hasattr(inner, "layers"):
        inner.layers = layers
    elif hasattr(inner, "h"):
        inner.h = layers
    else:
        raise ValueError(
            f"cannot install pipeline layers on {type(inner).__name__}: "
            "expected a 'layers' or 'h' attribute"
        )

    return model
no test coverage detected