r""" Cast a torch Tensor into the specified fp8 tensor with per-channel or per-tensor scaling. Args: inp: input torch Tensor; it should be a torch.FloatTensor, torch.HalfTensor, or torch.BFloat16Tensor. scale: scaling factor for fp8 casting. If it is None, then it is computed automatically.
(
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
)
| 49 | |
| 50 | |
| 51 | def cast_to_fp8( |
| 52 | inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None |
| 53 | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 54 | r""" |
| 55 | casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. |
| 56 | Args: |
| 57 | inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor. |
| 58 | scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling |
| 59 | is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. |
| 60 | fp8_format: e4m3 or e5m2 |
| 61 | |
| 62 | Returns: |
| 63 | Tuples: A tuple (fp8_tensor, scale) |
| 64 | """ |
| 65 | |
| 66 | if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: |
| 67 | raise TypeError("Only float16, bfloat16, and float32 are allowed.") |
| 68 | |
| 69 | fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 |
| 70 | fp8_max = torch.finfo(fp8_type).max |
| 71 | |
| 72 | if inp.numel() == 0: |
| 73 | return inp.to(fp8_type), torch.tensor([1.0], device=inp.device) |
| 74 | else: |
| 75 | if per_channel_scale: |
| 76 | per_channel_max = inp.abs().max(dim=-1).values.float() |
| 77 | per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) |
| 78 | scale = fp8_max / per_channel_max[:, None] |
| 79 | scale_inv = per_channel_max / fp8_max |
| 80 | else: |
| 81 | per_tensor_max = inp.abs().max().float() |
| 82 | per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) |
| 83 | scale = fp8_max / per_tensor_max |
| 84 | scale_inv = 1.0 / scale |
| 85 | |
| 86 | if out is not None: |
| 87 | ret = torch.mul(scale, inp.float(), out=out) |
| 88 | else: |
| 89 | ret = (scale * inp.float()).to(fp8_type) |
| 90 | return ret, torch.unsqueeze(scale_inv, dim=0) |
| 91 | |
| 92 | |
| 93 | def cast_from_fp8( |
searching dependent graphs…