(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0)
def _clamp_cast(x:Tensor, dtype:DType):
  # Clamp x into [dtype.min, dtype.max] first, then cast the clamped result to dtype.
  bounded = x.clamp(dtype.min, dtype.max)
  return bounded.cast(dtype)
| 508 | |
def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0):
  # Lays out `scale` (and `zero_point`, when it is a Tensor) relative to `x` and returns them as a
  # (scale, zero_point) tuple. A non-Tensor `zero_point` is handed back unchanged.
  # Reference: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_quantize_linear.py#L31
  target_axis = axis + x.ndim if axis < 0 else axis  # map a negative axis onto x's dimensions

  def _align(param:Tensor):
    # Single-element tensors are returned as they are.
    if param.numel() == 1:
      return param
    if block_size != 0:
      # Blocked case: repeat every entry block_size times along the target axis.
      return param.repeat_interleave(block_size, target_axis)
    # Per-axis case: param's leading size sits at target_axis, every other dimension is 1.
    lead = param.shape[0]
    return param.reshape([lead if i == target_axis else 1 for i in range(x.ndim)])

  aligned_scale = _align(scale)
  if isinstance(zero_point, Tensor):
    return (aligned_scale, _align(zero_point))
  return (aligned_scale, zero_point)
| 517 | |
| 518 | def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): |
| 519 | adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] |
no test coverage detected
searching dependent graphs…