| 128 | |
@dataclass
class DistConfig:
    """Distributed (data / expert / tensor) parallel configuration.

    The per-layer tensor-parallel sizes ``attn_tp``, ``mlp_tp`` and ``moe_tp``
    may be left as ``None``; they are then derived in ``__post_init__``. When
    ``dp == 1`` any explicit value for them is discarded and re-derived.

    ``__post_init__`` overwrites two fields:

    * ``world_size`` is always recomputed (``ep`` if ``ep > 1``, otherwise
      ``max(mlp_tp, moe_tp)``); a value passed by the caller is ignored.
    * ``tp`` is replaced by the resolved ``attn_tp``.
    """
    dp: int = 1
    ep: int = 1
    dp_rank: int = 0  # must satisfy 0 <= dp_rank < dp
    enable_microbatch: bool = False
    enable_eplb: bool = False
    world_size: int = 1  # recomputed in __post_init__

    # tp
    tp: int = 1  # base tp: default for mlp_tp; set to attn_tp by __post_init__
    attn_tp: int = None  # tp for attention; None -> world_size // dp
    mlp_tp: int = None  # tp for mlp; None -> tp
    moe_tp: int = None  # tp for moe; None -> 1 if ep > 1 else mlp_tp

    # tp mode
    mlp_tp_mode: TPMode = TPMode.DEFAULT
    moe_tp_mode: TPMode = TPMode.DEFAULT

    @staticmethod
    def _require(cond: bool, msg: str = ''):
        """Raise ``AssertionError(msg)`` when ``cond`` is false.

        Unlike a bare ``assert`` this is not stripped under ``python -O``,
        while raising the same exception type callers saw before.
        """
        if not cond:
            raise AssertionError(msg)

    @staticmethod
    def _divides(part: int, whole: int) -> bool:
        """Return True if ``whole >= part`` and ``part`` evenly divides ``whole``."""
        return whole >= part and whole % part == 0

    def __post_init__(self):
        """Validate the configuration and resolve the derived tp sizes.

        Raises:
            AssertionError: if a size is non-positive, ``dp_rank`` is outside
                ``[0, dp)``, or the sizes do not divide each other as required.
        """
        # Check dp before dp_rank so an invalid dp gets the clearer message.
        self._require(self.dp >= 1, f'dp {self.dp}')
        self._require(0 <= self.dp_rank < self.dp, f'dp_rank {self.dp_rank}, dp {self.dp}')
        # ep == 0 previously crashed with ZeroDivisionError in the modulo check
        # below, and a negative ep passed it silently.
        self._require(self.ep >= 1, f'ep {self.ep}')

        dp = self.dp
        tp = self.tp
        ep = self.ep

        # Per-layer tp overrides are ignored for dp == 1: re-derive all of them.
        if dp == 1:
            self.mlp_tp = None
            self.attn_tp = None
            self.moe_tp = None

        # mlp and moe tp; ``or`` keeps the existing behaviour of treating 0 as unset.
        self.mlp_tp = self.mlp_tp or tp
        self.moe_tp = self.moe_tp or (1 if ep > 1 else self.mlp_tp)
        # Non-positive sizes would otherwise divide by zero or pass the modulo
        # checks below (e.g. ``4 % -2 == 0``).
        self._require(self.mlp_tp >= 1, f'mlp_tp {self.mlp_tp}')
        self._require(self.moe_tp >= 1, f'moe_tp {self.moe_tp}')

        # world_size
        world_size = ep if ep > 1 else max(self.mlp_tp, self.moe_tp)
        self.world_size = world_size
        self._require(self._divides(dp, world_size), f'world_size {world_size}, dp {dp}')
        self._require(self._divides(ep, world_size), f'world_size {world_size}, ep {ep}')
        self._require(self._divides(self.mlp_tp, world_size), f'world_size {world_size}, mlp_tp {self.mlp_tp}')
        self._require(self._divides(self.moe_tp, world_size), f'world_size {world_size}, moe_tp {self.moe_tp}')

        # attn tp
        self.attn_tp = self.attn_tp or world_size // dp
        self._require(self.attn_tp >= 1, f'attn_tp {self.attn_tp}')
        self.tp = self.attn_tp
        if self.mlp_tp > 1:
            self._require(self._divides(self.attn_tp, self.mlp_tp), f'mlp_tp {self.mlp_tp}, attn_tp {self.attn_tp}')
        if self.moe_tp > 1:
            self._require(self._divides(self.attn_tp, self.moe_tp), f'moe_tp {self.moe_tp}, attn_tp {self.attn_tp}')
        self._require(self._divides(self.attn_tp, world_size), f'world_size {world_size}, attn_tp {self.attn_tp}')
no outgoing calls