MCPcopy
hub / github.com/MCG-NJU/VideoMAE / init_distributed_mode

Function init_distributed_mode

utils.py:249–290  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

247
248
def _first_slurm_hostname(node_list):
    """Return the first hostname that ``scontrol`` expands ``node_list`` to.

    Returns an empty string if ``scontrol`` prints nothing.
    """
    # Pass an argv list instead of a shell string: SLURM node lists such as
    # "node[01-04]" contain shell glob characters, and going through the shell
    # unquoted lets them be expanded against files in the working directory.
    # check_output also raises on a non-zero exit status, so an error message
    # can never end up being used as the master address.
    output = subprocess.check_output(
        ['scontrol', 'show', 'hostname', node_list], universal_newlines=True)
    hostnames = output.splitlines()
    return hostnames[0].strip() if hostnames else ''


def init_distributed_mode(args):
    """Populate ``args`` for distributed training and start the process group.

    The launch environment is detected in this order:

    1. ``args.dist_on_itp`` is truthy: rank, world size and local rank are
       read from the ``OMPI_COMM_WORLD_*`` variables and ``args.dist_url`` is
       built from ``MASTER_ADDR`` and ``MASTER_PORT``.
    2. ``SLURM_PROCID`` is in the environment: rank, local rank and world
       size are read from ``SLURM_PROCID``, ``SLURM_LOCALID`` and
       ``SLURM_NTASKS``; ``MASTER_ADDR`` is set to the first hostname of
       ``SLURM_NODELIST`` unless it is already defined.
    3. ``RANK`` and ``WORLD_SIZE`` are in the environment: values are read
       from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    4. Otherwise ``args.distributed`` is set to ``False`` and the function
       returns without initializing anything.

    Branches 1 and 2 also export ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE``
    into ``os.environ``. Only branch 1 assigns ``args.dist_url``; in
    branches 2 and 3 the caller must already have set it, because it is read
    when the process group is created.

    Args:
        args: Mutable namespace. ``dist_on_itp`` is read; ``rank``,
            ``world_size``, ``gpu``, ``distributed`` and ``dist_backend``
            are written.

    Raises:
        KeyError: If an environment variable required by the selected branch
            is missing.
        subprocess.CalledProcessError: If ``scontrol`` exits with a non-zero
            status while resolving the SLURM master address.
        OSError: If ``scontrol`` cannot be executed.
    """
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (
            os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])
        os.environ['RANK'] = str(args.rank)
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['WORLD_SIZE'] = str(args.world_size)

        # Only query scontrol when the launcher did not provide an address.
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = _first_slurm_hostname(
                os.environ['SLURM_NODELIST'])
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    # Bind this process to its local GPU before creating the process group.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)
    # Wait until every rank has joined before continuing.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
291
292
293def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):

Callers

nothing calls this directly

Calls 2

printFunction · 0.85
setup_for_distributedFunction · 0.85

Tested by

no test coverage detected