MCPcopy
hub / github.com/Comfy-Org/ComfyUI / quantize

Method quantize

comfy/quant_ops.py:82–110  ·  view source on GitHub ↗
(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False)

Source from the content-addressed store, hash-verified

80
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
    """Quantize ``tensor`` to ``cls.FP8_DTYPE`` with a single per-tensor scale.

    Args:
        tensor: Floating-point tensor to quantize.
        scale: Per-tensor scale. ``None`` means a scale of 1.0. The string
            ``"recalculate"`` derives the scale from the tensor's absolute
            maximum divided by the largest finite value of ``cls.FP8_DTYPE``.
            Any other non-tensor value is converted to a float32 tensor on
            ``tensor.device``; a tensor is used as given.
        stochastic_rounding: When greater than 0, the tensor is multiplied
            by ``1 / scale`` here and rounded with
            ``comfy.float.stochastic_rounding``, using this value as the
            seed. Otherwise quantization is delegated to
            ``ck.quantize_per_tensor_fp8``.
        inplace_ops: Only consulted when stochastic rounding is enabled.
            When true, the scaling is applied with ``*=`` and therefore
            modifies the caller's tensor.

    Returns:
        A ``(qdata, params)`` tuple where ``qdata`` is the quantized data and
        ``params`` is a ``cls.Params`` holding the float32 scale together
        with the original dtype and shape of ``tensor``.

    Raises:
        NotImplementedError: If the class does not define ``FP8_DTYPE``.
    """
    if cls.FP8_DTYPE is None:
        raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")

    # Captured before any scaling so the params describe the caller's input.
    orig_dtype = tensor.dtype
    orig_shape = tuple(tensor.shape)

    if isinstance(scale, str) and scale == "recalculate":
        scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
        # An all-zero (or subnormally small) tensor would give a scale whose
        # reciprocal overflows to inf; multiplying zeros by inf then yields
        # NaN. Clamping to the smallest normal float32 keeps 1 / scale finite.
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).tiny)
        if tensor.dtype not in [torch.float32, torch.bfloat16]:  # Prevent scale from being too small
            # Keep 1 / scale representable in the tensor's own dtype, since
            # the reciprocal is cast to that dtype before multiplying.
            tensor_info = torch.finfo(tensor.dtype)
            scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))

    if scale is None:
        scale = torch.ones((), device=tensor.device, dtype=torch.float32)
    if not isinstance(scale, torch.Tensor):
        scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)

    if stochastic_rounding > 0:
        # The stochastic rounding helper receives the already-scaled tensor,
        # so the division by the scale happens here.
        if inplace_ops:
            tensor *= (1.0 / scale).to(tensor.dtype)
        else:
            tensor = tensor * (1.0 / scale).to(tensor.dtype)
        qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
    else:
        qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)

    params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
    return qdata, params
111
112
113class TensorCoreMXFP8Layout(_CKMxfp8Layout):

Callers

nothing calls this directly

Calls 2

floatMethod · 0.80
toMethod · 0.45

Tested by

no test coverage detected