Get the number of GPUs per node: `get_num_gpu_per_node()`
| 36 | |
| 37 | |
def get_num_gpu_per_node():
    """Return how many processes run on each node of the distributed job.

    The value is computed as ``max(LOCAL_RANK) + 1``, taken over every rank
    in the default process group. Presumably each process drives one GPU,
    which is where the name comes from (NOTE(review): confirm with callers).
    The result is only meaningful when every node runs the same number of
    processes. It also requires ``LOCAL_RANK`` to be set for each process.
    When that variable is missing, every rank reads ``0`` and the result
    collapses to ``1``.

    Returns:
        int: ``1`` when ``get_dist_info()`` reports a world size of 1.
        Otherwise, the maximum local rank across all ranks plus one.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        # Single process: nothing to reduce over, so skip CUDA and the
        # collective entirely.
        return 1

    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    # all_reduce is a collective: every rank in the group must reach this call.
    # The tensor lives on the current CUDA device. This is presumably required
    # by an NCCL backend, which only reduces GPU tensors (TODO confirm).
    local_rank_tensor = torch.tensor([local_rank], device='cuda')
    torch.distributed.all_reduce(local_rank_tensor, op=ReduceOp.MAX)
    return local_rank_tensor.item() + 1
| 50 | |
| 51 | |
| 52 | def barrier(): |