(shape, dtype, async_op)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all_uneven(shape, dtype, async_op):
    """Compare ``all_to_all_single_fp8`` against ``dist.all_to_all_single`` with uneven splits.

    Every rank sends ``input_split_sizes[dst]`` rows of ``x`` (along dim 0) to
    rank ``dst``. Rank ``r`` therefore receives ``input_split_sizes[r]`` rows from
    each rank in the default group. Both collectives run on the same input, and
    their outputs are compared with a loose tolerance (``rtol=atol=0.1``).

    Args:
        shape: Shape of the random input tensor. ``all_to_all_single`` requires
            ``shape[0] == sum(input_split_sizes)`` (8 here). This is presumably
            guaranteed by a ``shape`` parameterization not visible in this view
            (TODO confirm).
        dtype: Floating dtype of the input (bfloat16 or float16 per the decorator).
        async_op: If True, both collectives are launched asynchronously and their
            handles are waited on before comparing.
    """
    device = get_accelerator().get_current_device()
    x = torch.rand(shape, dtype=dtype, device=device)

    # One split entry per destination rank; the four entries mean this check
    # assumes a world size of 4 (NOTE(review): confirm against run_dist/spawn).
    input_split_sizes = [3, 3, 1, 1]
    # Every peer sends this rank input_split_sizes[rank] rows, so derive the
    # receive side from the send side instead of hard-coding it:
    # ranks 0/1 -> [3, 3, 3, 3], ranks 2/3 -> [1, 1, 1, 1].
    rank = dist.get_rank()
    output_split_sizes = [input_split_sizes[rank]] * dist.get_world_size()

    output_shape = list(shape)
    output_shape[0] = sum(output_split_sizes)
    output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
    # Allocated in the input dtype (not an FP8 dtype) so it can be compared
    # element-wise with the reference output below.
    output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)

    group = _get_default_group()
    origin_handle = dist.all_to_all_single(
        output,
        x,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
        async_op=async_op,
    )
    fp8_handle = all_to_all_single_fp8(
        output_fp8,
        x,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
        async_op=async_op,
    )
    # Handles are only waited on when the collectives were launched asynchronously.
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    # Loose tolerance, presumably to absorb the precision lost by the FP8 path.
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)
| 63 | |
| 64 | |
| 65 | def run_dist(rank, world_size, port): |
no test coverage detected
searching dependent graphs…