MCP · copy
hub / github.com/InternLM/lmdeploy / DistConfig

Class DistConfig

lmdeploy/pytorch/config.py:130–221  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

128
129@dataclass
130class DistConfig:
131 dp: int = 1
132 ep: int = 1
133 dp_rank: int = 0
134 enable_microbatch: bool = False
135 enable_eplb: bool = False
136 world_size: int = 1
137
138 # tp
139 tp: int = 1 # default tp, equal to attn_tp
140 attn_tp: int = None # tp for attention
141 mlp_tp: int = None # tp for mlp
142 moe_tp: int = None # tp for moe
143
144 # tp mode
145 mlp_tp_mode: TPMode = TPMode.DEFAULT
146 moe_tp_mode: TPMode = TPMode.DEFAULT
147
148 def __post_init__(self):
149 """Post init."""
150 assert self.dp_rank < self.dp
151 assert self.dp >= 1
152
153 dp = self.dp
154 tp = self.tp
155 ep = self.ep
156
157 # ignore layer to for dp==1
158 if dp == 1:
159 self.mlp_tp = None
160 self.attn_tp = None
161 self.moe_tp = None
162
163 # mlp and moe tp
164 self.mlp_tp = self.mlp_tp or tp
165 self.moe_tp = self.moe_tp or (1 if ep > 1 else self.mlp_tp)
166
167 # world_size
168 world_size = ep if ep > 1 else max(self.mlp_tp, self.moe_tp)
169 self.world_size = world_size
170 assert (world_size >= dp and world_size % dp == 0), (f'world_size {world_size}, dp {dp}')
171 assert (world_size >= ep and world_size % ep == 0), (f'world_size {world_size}, ep {ep}')
172 assert (world_size >= self.mlp_tp
173 and world_size % self.mlp_tp == 0), (f'world_size {world_size}, mlp_tp {self.mlp_tp}')
174 assert (world_size >= self.moe_tp
175 and world_size % self.moe_tp == 0), (f'world_size {world_size}, moe_tp {self.moe_tp}')
176
177 # attn tp
178 self.attn_tp = self.attn_tp or self.world_size // dp
179 self.tp = self.attn_tp
180 if self.mlp_tp > 1:
181 assert (self.mlp_tp >= self.attn_tp
182 and self.mlp_tp % self.attn_tp == 0), (f'mlp_tp {self.mlp_tp}, attn_tp {self.attn_tp}')
183 if self.moe_tp > 1:
184 assert (self.moe_tp >= self.attn_tp
185 and self.moe_tp % self.attn_tp == 0), (f'moe_tp {self.moe_tp}, attn_tp {self.attn_tp}')
186 assert (world_size >= self.attn_tp
187 and world_size % self.attn_tp == 0), (f'world_size {world_size}, attn_tp {self.attn_tp}')

Calls

no outgoing calls