(cls, fp_out, input_q, scale, group_size, q_mantisa_bits, q_exponent_bits)
| 54 | |
| 55 | @classmethod |
| 56 | def dequantize(cls, fp_out, input_q, scale, group_size, q_mantisa_bits, q_exponent_bits): |
| 57 | orig_shape = fp_out.shape |
| 58 | orig_dtype = fp_out.dtype |
| 59 | dequant_out = torch.ops.hpu.cast_from_fp8(input_q, (1.0 / scale), orig_dtype).view(orig_shape) |
| 60 | fp_out.copy_(dequant_out) |
| 61 | return fp_out |
| 62 | |
| 63 | @classmethod |
| 64 | def quantize(cls, out, val, scale, group_size, stochastic_rounding, q_bits, q_mantisa_bits): |