(shape, dtype, fp8_format)
| 16 | @parameterize("dtype", [torch.bfloat16, torch.float16]) |
| 17 | @parameterize("fp8_format", ["e4m3", "e5m2"]) |
| 18 | def check_4gpu(shape, dtype, fp8_format): |
| 19 | x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) |
| 20 | output = torch.empty_like(x) |
| 21 | output_fp8 = torch.empty_like(x) |
| 22 | all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format) |
| 23 | dist.all_to_all_single(output, x, group=_get_default_group()) |
| 24 | assert_close(output, output_fp8, rtol=0.1, atol=0.1) |
| 25 | |
| 26 | |
| 27 | def run_dist(rank, world_size, port): |
no test coverage detected
searching dependent graphs…