(shape, dtype, fp8_format, async_op)
| 25 | @parameterize("fp8_format", ["e4m3", "e5m2"]) |
| 26 | @parameterize("async_op", [True, False]) |
def check_4gpu(shape, dtype, fp8_format, async_op):
    """Check that ``all_reduce_fp8`` stays close to ``dist.all_reduce``.

    A random tensor is created on the current accelerator device and cloned,
    so both reductions start from identical data. Two rounds are run on the
    same pair of tensors: the first with each function's default reduce op,
    the second with ``dist.ReduceOp.AVG``. After each round the two tensors
    are compared with ``rtol=0.1`` and ``atol=0.1``.

    Args:
        shape: Shape passed to ``torch.rand`` for the input tensor.
        dtype: Dtype of the input tensor.
        fp8_format: Forwarded to ``all_reduce_fp8`` as ``fp8_format``.
        async_op: Forwarded to both reductions; when truthy, the returned
            handles are waited on before the comparison.
    """
    device = get_accelerator().get_current_device()
    reference = torch.rand(shape, dtype=dtype, device=device)
    quantized = reference.clone()

    def _reduce_and_compare(**op_kwargs):
        # Launch the reference reduction first, then the FP8 one.
        pending = (
            dist.all_reduce(reference, async_op=async_op, **op_kwargs),
            all_reduce_fp8(quantized, fp8_format=fp8_format, async_op=async_op, **op_kwargs),
        )
        if async_op:
            # Handles are only waited on in async mode.
            for handle in pending:
                handle.wait()
        assert_close(reference, quantized, rtol=0.1, atol=0.1)

    # Round 1: default reduce op of each function.
    _reduce_and_compare()
    # Round 2: explicit averaging, applied to the tensors left by round 1.
    _reduce_and_compare(op=dist.ReduceOp.AVG)
| 43 | |
| 44 | |
| 45 | def run_dist(rank, world_size, port): |
no test coverage detected
searching dependent graphs…