Repository: github.com/hpcaitech/ColossalAI

Function cast_to_fp8

colossalai/quantization/fp8.py, lines 51–90

Casts a torch Tensor to the specified FP8 format (e4m3 or e5m2) using either per-channel or per-tensor scaling, and returns the FP8 tensor together with the inverse scaling factor needed to dequantize it.

(
    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
)
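The per-tensor branch can be mirrored in plain Python to show why the input is scaled before the cast. This is a minimal sketch, not ColossalAI code: `round_to_e4m3` and `cast_per_tensor` are hypothetical helpers, only the e4m3 format is handled, fp8 rounding is simulated in software, and clamping to ±448 is an assumption of the sketch rather than a description of torch's cast.

```python
import math
from typing import List, Tuple

# Largest finite value of torch.float8_e4m3fn, i.e. torch.finfo(torch.float8_e4m3fn).max.
E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Hypothetical helper: round x to the nearest e4m3fn value (round-half-to-even).

    Values beyond +/-448 are clamped; this saturation is an assumption of the sketch.
    """
    if x == 0.0:
        return 0.0
    _, e = math.frexp(x)        # x = m * 2**e with 0.5 <= |m| < 1
    exp = max(e - 1, -6)        # unbiased exponent; magnitudes below 2**-6 are subnormal
    step = 2.0 ** (exp - 3)     # 3 mantissa bits give 8 representable steps per binade
    q = round(x / step) * step  # Python's round() uses round-half-to-even
    return max(-E4M3_MAX, min(E4M3_MAX, q))


def cast_per_tensor(values: List[float]) -> Tuple[List[float], float]:
    """Hypothetical pure-Python mirror of the per-tensor branch of cast_to_fp8 (e4m3 only)."""
    amax = max(abs(v) for v in values)
    amax = amax if amax > 0 else 1.0  # all-zero input: avoid dividing by zero
    scale = E4M3_MAX / amax           # maps the largest magnitude onto 448
    scale_inv = 1.0 / scale           # returned so the caller can dequantize
    return [round_to_e4m3(v * scale) for v in values], scale_inv


small = [0.0001, 0.0002, 0.0003]
print([round_to_e4m3(v) for v in small])  # without scaling: [0.0, 0.0, 0.0]
q, scale_inv = cast_per_tensor(small)
print(q)  # [144.0, 288.0, 448.0]
print([v * scale_inv for v in q])  # dequantized approximations of the inputs
```

Without scaling, values around 1e-4 are below half of the smallest e4m3 subnormal (2^-9) and all round to zero. After scaling by 448 / amax they land on 144, 288 and 448, and multiplying by `scale_inv` recovers roughly 9.6e-05, 1.9e-04 and 3e-04: coarse, but no longer zero.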


# Imports this excerpt relies on.
from typing import Tuple

import torch


def cast_to_fp8(
    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Cast a torch Tensor to the specified fp8 format with per-channel or per-tensor scaling.
    The scaling factor is always computed from the input's absolute maximum.

    Args:
        inp: input torch Tensor; must be float32, float16, or bfloat16.
        fp8_format: "e4m3" (torch.float8_e4m3fn) or "e5m2" (torch.float8_e5m2).
        per_channel_scale: if True, use one scaling factor per row, taken over the last
            dimension (expects a 2-D input); otherwise use a single per-tensor factor.
        out: optional pre-allocated tensor that receives the scaled values.

    Returns:
        Tuple: (fp8_tensor, scale_inv), where scale_inv is the inverse scaling factor
        (shape (1,) per-tensor, (1, num_rows) per-channel) used to dequantize.
    """

    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError("Only float16, bfloat16, and float32 are allowed.")

    # Any fp8_format other than "e4m3" selects e5m2.
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max

    if inp.numel() == 0:
        # Empty input: nothing to scale, so return a unit inverse scale.
        return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
    else:
        if per_channel_scale:
            # Row-wise absolute maximum; all-zero rows fall back to 1.0 to avoid division by zero.
            per_channel_max = inp.abs().max(dim=-1).values.float()
            per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
            # [:, None] broadcasts each row's factor across its columns.
            scale = fp8_max / per_channel_max[:, None]
            scale_inv = per_channel_max / fp8_max
        else:
            # Single absolute maximum for the whole tensor; an all-zero tensor falls back to 1.0.
            per_tensor_max = inp.abs().max().float()
            per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
            scale = fp8_max / per_tensor_max
            scale_inv = 1.0 / scale

        # Scale so the largest magnitude maps to fp8_max, then cast (or write into `out`).
        if out is not None:
            ret = torch.mul(scale, inp.float(), out=out)
        else:
            ret = (scale * inp.float()).to(fp8_type)
        return ret, torch.unsqueeze(scale_inv, dim=0)
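In the per-channel branch, `scale` has shape (num_rows, 1) while the returned `scale_inv` is unsqueezed to (1, num_rows). The sketch below mirrors that bookkeeping in plain Python, stopping before the fp8 cast itself. It is an illustration under stated assumptions, not ColossalAI code: `per_channel_scales` is a hypothetical helper, the input is a 2-D list of rows, and the e4m3 maximum of 448 is used.

```python
from typing import List, Tuple

E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max


def per_channel_scales(rows: List[List[float]]) -> Tuple[List[List[float]], List[List[float]]]:
    """Hypothetical helper mirroring the per-channel branch of cast_to_fp8 before the fp8 cast.

    Returns (scaled_rows, scale_inv); scale_inv is nested as [[...]] to match the
    (1, num_rows) shape produced by torch.unsqueeze(scale_inv, dim=0).
    """
    # Row-wise absolute maximum over the last dimension; all-zero rows fall back to 1.0.
    row_max = [max(abs(v) for v in row) for row in rows]
    row_max = [m if m > 0 else 1.0 for m in row_max]
    # Same role as fp8_max / per_channel_max[:, None]: one factor per row.
    scaled = [[v * (E4M3_MAX / m) for v in row] for row, m in zip(rows, row_max)]
    scale_inv = [[m / E4M3_MAX for m in row_max]]
    return scaled, scale_inv


rows = [[1.0, -2.0], [0.0, 0.0], [0.5, 0.25]]
scaled, scale_inv = per_channel_scales(rows)
print(scaled)  # [[224.0, -448.0], [0.0, 0.0], [448.0, 224.0]]
# Dequantize row i with its own inverse factor, scale_inv[0][i].
restored = [[v * scale_inv[0][i] for v in row] for i, row in enumerate(scaled)]
```

Each row's largest magnitude maps to ±448 independently, so a row of small values keeps its resolution even when another row is much larger. The all-zero row stays zero and gets an inverse factor of 1 / 448.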

Callers (13)

test_fp8_cast (function) · 0.90
_all_reduce_fp8 (function) · 0.85
_all_to_all_single_fp8 (function) · 0.85
_reduce_scatter_fp8 (function) · 0.85
sum_and_allgather (function) · 0.85
_all_to_all_fp8 (function) · 0.85
_all_gather_fp8 (function) · 0.85
all_gather_fp8_lagacy (function) · 0.85
all_gather_fp8_ring (function) · 0.85
forward (method) · 0.85

Calls (1)

to (method) · 0.45

Tested by (1)

test_fp8_cast (function) · 0.72
