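Function test_fp8_cast

Repository: github.com/hpcaitech/ColossalAI
Source: tests/test_fp8/test_fp8_cast.py, lines 13–23
Signature: test_fp8_cast(shape, dtype, fp8_format)

Round-trips a random tensor through FP8 and back to its original dtype, and checks that the result matches the input within rtol=0.1 and atol=0.1. The test runs for every parameterized shape and every combination of dtype (bfloat16, float16, float32) and FP8 format (e4m3, e5m2). It first uses the cast_to_fp8 / cast_from_fp8 pair. Then, when the last dimension of the tensor is even, it also exercises the dictionary-based cast_to_fp8_pipeline / cast_from_fp8_pipeline pair.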

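# Imports assumed from the ColossalAI package layout; they are not part of the shown excerpt.
import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


# A `shape` parameterization (not shown in this excerpt) supplies the `shape` argument.
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    # Round trip: original dtype -> FP8 data plus inverse scale -> original dtype.
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

    # The pipeline variants are only exercised when the last dimension is even.
    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()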
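The test suggests that cast_to_fp8 returns FP8 data together with an inverse scale (scale_inv), which cast_from_fp8 then uses to undo the scaling. The pure-Python sketch below illustrates that idea using per-tensor amax scaling and round-to-nearest-even onto the E4M3 and E5M2 grids. It is not ColossalAI's implementation. The helper names (round_to_fp8, cast_to_fp8_sketch, cast_from_fp8_sketch) and the amax-based scaling are hypothetical choices made for illustration. The formats follow the OCP FP8 definitions, with E4M3 saturating at 448 and E5M2 at 57344.

```python
import math
import random
from typing import List, Tuple

# Hypothetical model of the two OCP FP8 formats (not ColossalAI code).
FP8_FORMATS = {
    "e4m3": {"man_bits": 3, "bias": 7, "max": 448.0},
    "e5m2": {"man_bits": 2, "bias": 15, "max": 57344.0},
}


def round_to_fp8(v: float, fmt: str) -> float:
    """Round a finite float to the nearest value representable in `fmt`, saturating at the max."""
    spec = FP8_FORMATS[fmt]
    a = min(abs(v), spec["max"])  # saturate instead of producing inf/NaN
    if a == 0.0:
        return math.copysign(0.0, v)
    _, exp = math.frexp(a)  # a = mantissa * 2**exp with 0.5 <= mantissa < 1
    # Binade exponent; subnormals share the smallest normal exponent (1 - bias).
    e = max(exp - 1, 1 - spec["bias"])
    step = 2.0 ** (e - spec["man_bits"])  # spacing of representable values in this binade
    q = round(a / step) * step  # Python's round() rounds halves to even
    return math.copysign(min(q, spec["max"]), v)


def cast_to_fp8_sketch(values: List[float], fmt: str = "e4m3") -> Tuple[List[float], float]:
    """Per-tensor scaling (assumed): map the largest magnitude onto the format's max, then round."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = FP8_FORMATS[fmt]["max"] / amax if amax > 0 else 1.0
    quantized = [round_to_fp8(v * scale, fmt) for v in values]
    return quantized, 1.0 / scale  # keep the inverse scale for dequantization


def cast_from_fp8_sketch(quantized: List[float], scale_inv: float) -> List[float]:
    """Undo the scaling by multiplying with the stored inverse scale."""
    return [q * scale_inv for q in quantized]


if __name__ == "__main__":
    random.seed(0)
    x = [random.random() for _ in range(1000)]  # uniform in [0, 1), like torch.rand
    for fmt in FP8_FORMATS:
        q, scale_inv = cast_to_fp8_sketch(x, fmt)
        out = cast_from_fp8_sketch(q, scale_inv)
        # Same criterion as torch.testing.assert_close: |out - x| <= atol + rtol * |x|
        ok = all(abs(o - v) <= 0.1 + 0.1 * abs(v) for o, v in zip(out, x))
        print(fmt, "within tolerance:", ok)
```

In this model, rounding a normal value with m mantissa bits changes it by at most a factor of 2^-(m+1). That is 6.25% for E4M3 and 12.5% for E5M2. For inputs in [0, 1), a 12.5% relative error stays below the combined bound 0.1 + 0.1·|x|, which is plausibly why one tolerance suffices for both formats.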

Callers (1): test_fp8_cast.py (file, 0.85)

Calls (8):
get_accelerator (function, 0.90)
cast_to_fp8 (function, 0.90)
cast_from_fp8 (function, 0.90)
cast_to_fp8_pipeline (function, 0.90)
cast_from_fp8_pipeline (function, 0.90)
get_current_device (method, 0.45)
size (method, 0.45)
clone (method, 0.45)

Tested by: no test coverage detected
