MCP · copy
hub / github.com/zai-org/CogVideo / initialize_distributed

Function initialize_distributed

sat/arguments.py:179–253  ·  view source on GitHub ↗

Initialize torch.distributed.

(args)

Source from the content-addressed store, hash-verified

177
178
179def initialize_distributed(args):
180 """Initialize torch.distributed."""
181 if torch.distributed.is_initialized():
182 if mpu.model_parallel_is_initialized():
183 if args.model_parallel_size != mpu.get_model_parallel_world_size():
184 raise ValueError(
185 "model_parallel_size is inconsistent with prior configuration."
186 "We currently do not support changing model_parallel_size."
187 )
188 return False
189 else:
190 if args.model_parallel_size > 1:
191 warnings.warn(
192 "model_parallel_size > 1 but torch.distributed is not initialized via SAT."
193 "Please carefully make sure the correctness on your own."
194 )
195 mpu.initialize_model_parallel(args.model_parallel_size)
196 return True
197 # the automatic assignment of devices has been moved to arguments.py
198 if args.device == "cpu":
199 pass
200 else:
201 torch.cuda.set_device(args.device)
202 # Call the init process
203 init_method = "tcp://"
204 args.master_ip = os.getenv("MASTER_ADDR", "localhost")
205
206 if args.world_size == 1:
207 from sat.helpers import get_free_port
208
209 default_master_port = str(get_free_port())
210 else:
211 default_master_port = "6000"
212 args.master_port = os.getenv("MASTER_PORT", default_master_port)
213 init_method += args.master_ip + ":" + args.master_port
214 torch.distributed.init_process_group(
215 backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
216 )
217
218 # Set the model-parallel / data-parallel communicators.
219 mpu.initialize_model_parallel(args.model_parallel_size)
220
221 # Set vae context parallel group equal to model parallel group
222 from sgm.util import set_context_parallel_group, initialize_context_parallel
223
224 if args.model_parallel_size <= 2:
225 set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
226 else:
227 initialize_context_parallel(2)
228 # mpu.initialize_model_parallel(1)
229 # Optional DeepSpeed Activation Checkpointing Features
230 if args.deepspeed:
231 import deepspeed
232
233 deepspeed.init_distributed(
234 dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
235 )
236 # # It seems that it has no negative influence to configure it even without using checkpointing.

Callers 1

get_args · Function · 0.85

Calls 2

Tested by

no test coverage detected