(shape, dtype, fp8_format)
| 11 | @parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32]) |
| 12 | @parameterize("fp8_format", ["e4m3", "e5m2"]) |
| 13 | def test_fp8_cast(shape, dtype, fp8_format): |
| 14 | x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) |
| 15 | ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format) |
| 16 | out = cast_from_fp8(ret, scale_inv, x.dtype) |
| 17 | assert_close(out, x, rtol=0.1, atol=0.1) |
| 18 | |
| 19 | if x.size(-1) % 2 == 0: |
| 20 | inp_dict = {"hidden_states": x.clone()} |
| 21 | cast_to_fp8_pipeline(inp_dict) |
| 22 | cast_from_fp8_pipeline(inp_dict) |
| 23 | assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1) |
| 24 | |
| 25 | |
| 26 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…