github.com/hpcaitech/ColossalAI

Function check_all2all_uneven

tests/test_fp8/test_all_to_all_single.py:32–62
(shape, dtype, async_op)


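# Imports assumed for this excerpt; the file header is not shown here.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize


# `shape` is not parameterized in the lines shown; shape[0] must equal
# sum(input_split_sizes) == 8.
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all_uneven(shape, dtype, async_op):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    # The split lists have four entries, so the group must have four ranks.
    # Every rank sends 3 rows each to ranks 0 and 1 and 1 row each to ranks 2 and 3.
    input_split_sizes = [3, 3, 1, 1]
    # Receive side: ranks 0 and 1 get 3 rows from each peer, ranks 2 and 3 get 1.
    if dist.get_rank() in [0, 1]:
        output_split_sizes = [3, 3, 3, 3]
    else:
        output_split_sizes = [1, 1, 1, 1]
    output_shape = list(shape)
    output_shape[0] = sum(output_split_sizes)
    output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
    output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
    # Reference: full-precision uneven all-to-all from torch.distributed.
    origin_handle = dist.all_to_all_single(
        output,
        x,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=_get_default_group(),
        async_op=async_op,
    )
    # The same exchange through the FP8 variant under test.
    fp8_handle = all_to_all_single_fp8(
        output_fp8,
        x,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=_get_default_group(),
        async_op=async_op,
    )
    if async_op:
        # Wait for both collectives before reading their outputs.
        origin_handle.wait()
        fp8_handle.wait()
    # FP8 is lossy, so the results are compared with a loose tolerance.
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    ...  # body not shown in this excerpt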
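Because every rank uses the same input_split_sizes, rank d receives input_split_sizes[d] rows from each peer, which is where the receive-side lists in the test come from. A minimal single-process sketch of that bookkeeping, assuming four ranks and the test's input_split_sizes = [3, 3, 1, 1]: it does not use torch.distributed at all, tags each row with (sender, row_index) instead of real data, and the helper names are hypothetical.

```python
# Hypothetical single-process replay of an uneven all_to_all_single exchange.
# Rows are tagged (sender, row_index) instead of holding tensor data.

WORLD_SIZE = 4  # the test's split lists have 4 entries, so it needs 4 ranks


def input_splits(rank):
    # Same on every rank, as in the test: 3 rows to ranks 0 and 1, 1 row to ranks 2 and 3.
    return [3, 3, 1, 1]


def simulate_all_to_all_single(world_size, splits_fn):
    """Return {rank: (received_rows, output_split_sizes)}."""
    # Each rank's input holds sum(splits) rows, labelled by sender.
    inputs = {r: [(r, i) for i in range(sum(splits_fn(r)))] for r in range(world_size)}
    results = {}
    for dst in range(world_size):
        received, out_splits = [], []
        for src in range(world_size):
            splits = splits_fn(src)
            start = sum(splits[:dst])  # offset of the chunk src sends to dst
            chunk = inputs[src][start:start + splits[dst]]
            received.extend(chunk)
            out_splits.append(len(chunk))
        results[dst] = (received, out_splits)
    return results


results = simulate_all_to_all_single(WORLD_SIZE, input_splits)
for rank, (rows, out_splits) in results.items():
    print(rank, out_splits, len(rows))
# prints:
# 0 [3, 3, 3, 3] 12
# 1 [3, 3, 3, 3] 12
# 2 [1, 1, 1, 1] 4
# 3 [1, 1, 1, 1] 4
```

Each rank's input therefore needs sum(input_split_sizes) = 8 rows, and ranks 0 and 1 end up with 12 rows while ranks 2 and 3 end up with 4, matching output_shape[0] in the test.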

Callers 1

run_dist · Function · 0.85

Calls 6

get_accelerator · Function · 0.90
all_to_all_single_fp8 · Function · 0.90
get_rank · Method · 0.80
empty · Method · 0.80
get_current_device · Method · 0.45
wait · Method · 0.45
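The test compares the output of all_to_all_single_fp8 with the full-precision result using rtol=0.1 and atol=0.1, because FP8 keeps only a few mantissa bits. This page does not show which FP8 format or scaling all_to_all_single_fp8 uses, so the following is a simplified, hypothetical model: it rounds each value's significand to 3 explicit mantissa bits (as in E4M3) or 2 (as in E5M2), ignores exponent range, subnormals, saturation and scaling, and checks the result with the element-wise criterion assert_close applies, |actual - expected| <= atol + rtol * |expected|.

```python
import math
import random

ATOL = RTOL = 0.1  # tolerances used by the test's assert_close call


def round_to_fp8_grid(v, mantissa_bits):
    """Round v to the nearest float with `mantissa_bits` explicit mantissa bits.

    Simplified model: only the significand is rounded (ties to even via round()).
    Exponent range limits, subnormals, saturation and any scaling are ignored.
    """
    if v == 0.0:
        return 0.0
    m, e = math.frexp(v)  # v == m * 2**e with 0.5 <= |m| < 1
    steps = 2 ** (mantissa_bits + 1)  # significands per binade, incl. implicit bit
    return math.ldexp(round(m * steps) / steps, e)


def within_tolerance(actual, expected):
    # Same element-wise criterion as torch.testing.assert_close.
    return abs(actual - expected) <= ATOL + RTOL * abs(expected)


random.seed(0)
values = [random.random() for _ in range(10_000)]  # uniform in [0, 1), like torch.rand

for name, bits in [("E4M3", 3), ("E5M2", 2)]:
    rounded = [round_to_fp8_grid(v, bits) for v in values]
    max_rel = max(abs(r - v) / v for r, v in zip(rounded, values) if v > 0)
    ok = all(within_tolerance(r, v) for r, v in zip(rounded, values))
    print(f"{name}: max relative error {max_rel:.4f}, within tolerance: {ok}")
```

Under this model the relative rounding error is at most 2^-(bits+1), that is 1/16 for 3 mantissa bits and 1/8 for 2, and for inputs in [0, 1) the absolute error is at most 1/32 or 1/16, both inside the test's atol of 0.1.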

Tested by

no test coverage detected
