@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_elements", [128, 3])
class TestDistInferenceAllReduce(DistributedTest):
    """Check ``dist.inference_all_reduce`` across dtypes and tensor lengths.

    ``world_size`` lists the process counts to exercise. It is capped by the
    number of devices the accelerator reports (presumably consumed by
    ``DistributedTest`` to launch one run per entry -- confirm in the harness).
    """

    # Evaluated once, when the class body executes.
    device_count = get_accelerator().device_count()

    # Fewer than 2 devices -> single process; 2-3 -> up to 2; 4+ -> up to 4.
    if device_count < 2:
        world_size = [1]
    elif device_count < 4:
        world_size = [1, 2]
    else:
        world_size = [1, 2, 4]

    def test(self, dtype, num_elements):
        """Each rank contributes ``rank + 1``; every element must end up as the sum.

        After the call, the test expects the rank's own tensor to hold
        ``1 + 2 + ... + world_size`` in every position, i.e. the reduction is
        a sum written back into the input tensor.
        """
        device = get_accelerator().device_name()
        world = dist.get_world_size()
        rank = dist.get_rank()

        # Closed form of 1 + 2 + ... + world.
        expected_fill = world * (world + 1) // 2

        # Build both tensors in the default dtype on the device, then cast.
        contribution = (torch.ones(1, num_elements).to(device) * (rank + 1)).to(dtype)
        expected = (torch.ones(1, num_elements).to(device) * expected_fill).to(dtype)

        dist.inference_all_reduce(contribution)
        assert torch.all(contribution == expected)
| 150 | |
| 151 | |
| 152 | @pytest.mark.parametrize("dist_init_required", [True, False, None]) |
nothing calls this directly
no test coverage detected
searching dependent graphs…