MCPcopy
hub / github.com/Audio-AGI/AudioSep / init_distributed_device

Function init_distributed_device

models/CLAP/training/distributed.py:70–150  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

68
69
70def init_distributed_device(args):
71 # Distributed training = training on more than one GPU.
72 # Works in both single and multi-node scenarios.
73 args.distributed = False
74 args.world_size = 1
75 args.rank = 0 # global rank
76 args.local_rank = 0
77 if args.horovod:
78 assert hvd is not None, "Horovod is not installed"
79 hvd.init()
80 world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
81 world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
82 local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
83 args.local_rank = local_rank
84 args.rank = world_rank
85 args.world_size = world_size
86 # args.local_rank = int(hvd.local_rank())
87 # args.rank = hvd.rank()
88 # args.world_size = hvd.size()
89 args.distributed = True
90 os.environ["LOCAL_RANK"] = str(args.local_rank)
91 os.environ["RANK"] = str(args.rank)
92 os.environ["WORLD_SIZE"] = str(args.world_size)
93 print(
94 f"Distributed training: local_rank={args.local_rank}, "
95 f"rank={args.rank}, world_size={args.world_size}, "
96 f"hostname={socket.gethostname()}, pid={os.getpid()}"
97 )
98 elif is_using_distributed():
99 if "SLURM_PROCID" in os.environ:
100 # DDP via SLURM
101 args.local_rank, args.rank, args.world_size = world_info_from_env()
102 # SLURM var -> torch.distributed vars in case needed
103 os.environ["LOCAL_RANK"] = str(args.local_rank)
104 os.environ["RANK"] = str(args.rank)
105 os.environ["WORLD_SIZE"] = str(args.world_size)
106 torch.distributed.init_process_group(
107 backend=args.dist_backend,
108 init_method=args.dist_url,
109 world_size=args.world_size,
110 rank=args.rank,
111 )
112 elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster
113 world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
114 world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
115 local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
116 args.local_rank = local_rank
117 args.rank = world_rank
118 args.world_size = world_size
119 torch.distributed.init_process_group(
120 backend=args.dist_backend,
121 init_method=args.dist_url,
122 world_size=args.world_size,
123 rank=args.rank,
124 )
125 else:
126 # DDP via torchrun, torch.distributed.launch
127 args.local_rank, _, _ = world_info_from_env()

Callers 2

main (Function) · 0.90
main (Function) · 0.90

Calls 2

is_using_distributed (Function) · 0.85
world_info_from_env (Function) · 0.85

Tested by

no test coverage detected