MCPcopy
hub / github.com/deepspeedai/DeepSpeed / TestDistInferenceAllReduce

Class TestDistInferenceAllReduce

tests/unit/comm/test_dist.py:133–149  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_elements", [128, 3])
class TestDistInferenceAllReduce(DistributedTest):
    """Check ``dist.inference_all_reduce`` against a closed-form expected value.

    Every rank contributes a ``(1, num_elements)`` tensor filled with
    ``rank + 1``.  After the reduction each element is expected to equal
    ``1 + 2 + ... + world_size``, i.e. ``world_size * (world_size + 1) // 2``.
    """

    # Candidate world sizes grow with the number of devices the accelerator
    # reports: always 1, plus 2 when >= 2 devices, plus 4 when >= 4 devices.
    # NOTE(review): presumably DistributedTest runs the test once per listed
    # world size -- confirm against the base class.
    device_count = get_accelerator().device_count()
    world_size = [1]
    if device_count >= 2:
        world_size.append(2)
    if device_count >= 4:
        world_size.append(4)

    def test(self, dtype, num_elements):
        """Reduce a rank-dependent tensor and compare it with the expected sum.

        Args:
            dtype: torch dtype both tensors are cast to before the reduction.
            num_elements: number of columns in the ``(1, num_elements)`` tensors.
        """
        device = get_accelerator().device_name()
        ranks = dist.get_world_size()

        def filled(value):
            # Build the tensor on the accelerator first, scale it, and only
            # then cast to the dtype under test (same order as before).
            return (torch.ones(1, num_elements).to(device) * value).to(dtype)

        expected = filled((ranks * (ranks + 1)) // 2)
        x = filled(dist.get_rank() + 1)

        # The call's return value is not used; the input tensor itself is
        # what gets compared against the expected result.
        dist.inference_all_reduce(x)
        assert torch.all(x == expected)
150
151
152@pytest.mark.parametrize("dist_init_required", [True, False, None])

Callers

nothing calls this directly

Calls 2

get_accelerator — Function · 0.90
device_count — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…