| 80 | |
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
    """Quantize ``tensor`` to ``cls.FP8_DTYPE`` using a single scale value.

    Args:
        tensor: Tensor to quantize.
        scale: ``None`` for a scale of 1.0, the string ``"recalculate"`` to
            derive the scale as ``amax(|tensor|) / finfo(FP8_DTYPE).max``,
            an existing ``torch.Tensor``, or any value accepted by
            ``torch.tensor``.
        stochastic_rounding: When greater than 0, ``tensor`` is multiplied
            by the reciprocal of the scale here and then passed to
            ``comfy.float.stochastic_rounding`` with this value as the seed.
            Otherwise ``ck.quantize_per_tensor_fp8`` performs the
            quantization.
        inplace_ops: Only consulted on the stochastic-rounding path; when
            true, the caller's ``tensor`` is scaled in place.

    Returns:
        ``(qdata, params)`` where ``params`` is a ``cls.Params`` holding the
        scale as float32 plus the dtype and shape ``tensor`` had on entry.

    Raises:
        NotImplementedError: If ``cls.FP8_DTYPE`` is ``None``.
    """
    target_dtype = cls.FP8_DTYPE
    if target_dtype is None:
        raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")

    # Captured up front so the returned params describe the input as given.
    source_dtype = tensor.dtype
    source_shape = tuple(tensor.shape)

    if isinstance(scale, str) and scale == "recalculate":
        peak = torch.amax(tensor.abs()).to(dtype=torch.float32)
        scale = peak / torch.finfo(target_dtype).max
        if source_dtype not in (torch.float32, torch.bfloat16):
            # Prevent scale from being too small: bound its reciprocal to
            # the finite range of the input dtype, then invert back.
            limits = torch.finfo(source_dtype)
            bounded_inverse = torch.clamp(1.0 / scale, min=limits.min, max=limits.max)
            scale = 1.0 / bounded_inverse

    # Normalise whatever was supplied into a float32 tensor on the input's device.
    if scale is None:
        scale = torch.ones((), device=tensor.device, dtype=torch.float32)
    elif not isinstance(scale, torch.Tensor):
        scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)

    if stochastic_rounding > 0:
        inverse_scale = (1.0 / scale).to(tensor.dtype)
        if inplace_ops:
            # Deliberately mutates the caller's tensor.
            tensor *= inverse_scale
        else:
            tensor = tensor * inverse_scale
        qdata = comfy.float.stochastic_rounding(tensor, dtype=target_dtype, seed=stochastic_rounding)
    else:
        qdata = ck.quantize_per_tensor_fp8(tensor, scale, target_dtype)

    return qdata, cls.Params(scale=scale.float(), orig_dtype=source_dtype, orig_shape=source_shape)
| 111 | |
| 112 | |
| 113 | class TensorCoreMXFP8Layout(_CKMxfp8Layout): |