Initialize torch.distributed.
(args)
| 177 | |
| 178 | |
| 179 | def initialize_distributed(args): |
| 180 | """Initialize torch.distributed.""" |
| 181 | if torch.distributed.is_initialized(): |
| 182 | if mpu.model_parallel_is_initialized(): |
| 183 | if args.model_parallel_size != mpu.get_model_parallel_world_size(): |
| 184 | raise ValueError( |
| 185 | "model_parallel_size is inconsistent with prior configuration." |
| 186 | "We currently do not support changing model_parallel_size." |
| 187 | ) |
| 188 | return False |
| 189 | else: |
| 190 | if args.model_parallel_size > 1: |
| 191 | warnings.warn( |
| 192 | "model_parallel_size > 1 but torch.distributed is not initialized via SAT." |
| 193 | "Please carefully make sure the correctness on your own." |
| 194 | ) |
| 195 | mpu.initialize_model_parallel(args.model_parallel_size) |
| 196 | return True |
| 197 | # the automatic assignment of devices has been moved to arguments.py |
| 198 | if args.device == "cpu": |
| 199 | pass |
| 200 | else: |
| 201 | torch.cuda.set_device(args.device) |
| 202 | # Call the init process |
| 203 | init_method = "tcp://" |
| 204 | args.master_ip = os.getenv("MASTER_ADDR", "localhost") |
| 205 | |
| 206 | if args.world_size == 1: |
| 207 | from sat.helpers import get_free_port |
| 208 | |
| 209 | default_master_port = str(get_free_port()) |
| 210 | else: |
| 211 | default_master_port = "6000" |
| 212 | args.master_port = os.getenv("MASTER_PORT", default_master_port) |
| 213 | init_method += args.master_ip + ":" + args.master_port |
| 214 | torch.distributed.init_process_group( |
| 215 | backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method |
| 216 | ) |
| 217 | |
| 218 | # Set the model-parallel / data-parallel communicators. |
| 219 | mpu.initialize_model_parallel(args.model_parallel_size) |
| 220 | |
| 221 | # Set vae context parallel group equal to model parallel group |
| 222 | from sgm.util import set_context_parallel_group, initialize_context_parallel |
| 223 | |
| 224 | if args.model_parallel_size <= 2: |
| 225 | set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group()) |
| 226 | else: |
| 227 | initialize_context_parallel(2) |
| 228 | # mpu.initialize_model_parallel(1) |
| 229 | # Optional DeepSpeed Activation Checkpointing Features |
| 230 | if args.deepspeed: |
| 231 | import deepspeed |
| 232 | |
| 233 | deepspeed.init_distributed( |
| 234 | dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method |
| 235 | ) |
| 236 | # # It seems that it has no negative influence to configure it even without using checkpointing. |
no test coverage detected