(rank, world_size, port)
| 177 | |
| 178 | |
def check_comm(rank, world_size, port):
    """Run the DeviceMesh communication checks on a single worker process.

    Starts the distributed runtime for this worker (NCCL backend, host
    ``localhost``). It then builds a 2x2 DeviceMesh over physical mesh ids
    0..3 with ``init_process_group=True`` and runs each check helper in a
    fixed sequence.

    Args:
        rank: Rank of this worker; asserted to equal ``dist.get_rank()``
            once ``launch`` returns.
        world_size: Total number of workers, forwarded to ``launch``.
        port: Port forwarded to ``launch`` (the host is ``localhost``).
    """
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    mesh_ids = torch.arange(0, 4)
    assert rank == dist.get_rank()

    # Logical layout of the 2x2 mesh:
    # [[0, 1],
    #  [2, 3]]
    mesh = DeviceMesh(mesh_ids, (2, 2), init_process_group=True)

    # Every rank executes this same sequence; keep the order unchanged.
    ordered_checks = (
        check_all_gather,
        check_shard,
        check_all_to_all,
        check_all_reduce_fwd,
        check_all_reduce_bwd,
        # all-reduce on the 1D-flattened view of the mesh
        check_all_reduce_in_flatten_device_mesh,
    )
    for run_check in ordered_checks:
        run_check(mesh, rank)
| 205 | |
| 206 | |
| 207 | @pytest.mark.dist |
nothing calls this directly
no test coverage detected
searching dependent graphs…