(rank, world_size, port)
| 9 | |
| 10 | |
def check_layer(rank, world_size, port):
    """Check that every axis of a 2x2 device mesh has a usable process group.

    After launching the distributed environment, a ``DeviceMesh`` is built
    over ranks 0..3 with ``init_process_group=True``.  For each mesh axis a
    tensor of ones is sum-all-reduced over that axis' process group and the
    result is compared against a tensor of twos.

    Args:
        rank: Rank of this worker; forwarded to ``launch`` and checked
            against ``dist.get_rank()``.
        world_size: Total number of workers; forwarded to ``launch``.
            NOTE(review): the mesh covers ranks 0..3, so this is presumably
            4 -- confirm against the caller.
        port: Port on ``localhost`` forwarded to ``launch``.

    Raises:
        AssertionError: If the reported rank differs from ``rank`` or an
            all-reduce result does not equal the expected tensor.
    """
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    physical_mesh_id = torch.arange(0, 4)
    # The launcher must have registered this worker under the rank it was given.
    assert rank == dist.get_rank()

    # Summing ones across one mesh axis is expected to give 2 per element.
    expected = torch.tensor([2, 2, 2, 2]).cuda()

    # Logical layout of the ranks:
    # [[0, 1],
    #  [2, 3]]
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    for axis, _ in enumerate(mesh_shape):
        ones = torch.ones(4).cuda()
        dist.all_reduce(ones, op=ReduceOp.SUM, group=device_mesh.get_process_group(axis=axis))
        assert ones.equal(expected)
| 28 | |
| 29 | |
| 30 | @pytest.mark.dist |
nothing calls this directly
no test coverage detected
searching dependent graphs…