(q_bits, is_symmetric_quant, max, min, absmax, scales=None, zero_points=None)
| 48 | |
| 49 | |
def get_scale_zero_point(q_bits, is_symmetric_quant, max, min, absmax, scales=None, zero_points=None):
    """Compute the quantization scale and zero point from range statistics.

    Args:
        q_bits: Bit width, forwarded unchanged to ``get_q_props``.
        is_symmetric_quant: When truthy, the scale is derived from ``absmax``
            and the zero point is all zeros; otherwise both are derived from
            ``max`` and ``min``.
        max: Tensor of maxima; read only in the asymmetric branch. (The name
            shadows the builtin but is kept so keyword callers keep working.)
        min: Tensor of minima; read only in the asymmetric branch.
        absmax: Tensor of absolute maxima; read only in the symmetric branch.
        scales: Accepted for interface compatibility; not used here.
        zero_points: Accepted for interface compatibility; not used here.

    Returns:
        A ``(scale, zero_point)`` tuple. ``scale`` has the shape and dtype of
        ``absmax`` (symmetric) or ``max`` (asymmetric). In the symmetric case
        ``zero_point`` is a float32 zero tensor of ``scale``'s shape allocated
        on the accelerator device; in the asymmetric case it is
        ``q_min - min * scale``.
    """
    # get_q_props returns a 3-tuple; the middle element is not needed here.
    q_range, _, q_min = get_q_props(q_bits)

    if is_symmetric_quant:
        # A degenerate range (absmax == 0) would divide by zero, so those
        # entries get a neutral scale of 1. torch.where evaluates both
        # operands, but the inf/nan produced for degenerate entries is never
        # selected. Doing this elementwise avoids a Python-level loop with a
        # host-side branch per element.
        candidate = (q_range / (2 * absmax)).to(absmax.dtype)
        scale = torch.where(absmax == 0, torch.ones_like(absmax), candidate)
        zero_point = torch.zeros(scale.shape, dtype=torch.float32, device=get_accelerator().device_name())
    else:
        # Same guard for the asymmetric case: max == min means a zero-width
        # range, which also falls back to a scale of 1.
        candidate = (q_range / (max - min)).to(max.dtype)
        scale = torch.where(max == min, torch.ones_like(max), candidate)
        zero_point = q_min - (min * scale)

    return scale, zero_point
| 66 | |
| 67 | |
| 68 | def int4x2to2xint4(int4X2tensor): |
no test coverage detected
searching dependent graphs…