(data, metadata, original_dtype=np.float32)
| 149 | |
| 150 | |
def dequantize_weights(data, metadata, original_dtype=np.float32):
    """Convert quantized weight values back to `original_dtype`.

    Two quantization schemes are recognized, selected by `data.dtype`:

    * uint8 / uint16: affine quantization. Values are restored as
      `data * metadata['scale'] + metadata['min']`. When `original_dtype`
      is `np.int32` the result is rounded before the cast.
    * float16: the values are simply cast to `original_dtype`, which must
      be `np.float32`.

    Args:
        data: Array-like object exposing `.dtype` and `.astype` (a NumPy
            array) holding the quantized values.
        metadata: Mapping that must contain the keys 'scale' and 'min' when
            `data` is uint8 or uint16. It is not read for float16 data.
        original_dtype: The dtype to restore the values to. Defaults to
            `np.float32`.

    Returns:
        The dequantized values as an array of `original_dtype`.

    Raises:
        ValueError: If 'scale' or 'min' is missing from `metadata` for
            uint8/uint16 data, if float16 data is requested as anything
            other than float32, or if `data.dtype` is not one of uint8,
            uint16, float16.
    """
    dtype = data.dtype

    if dtype in [np.uint8, np.uint16]:
        if not ('scale' in metadata and 'min' in metadata):
            raise ValueError(
                'Missing metadata min or scale for dtype %s' % dtype.name)
        restored = data * metadata['scale'] + metadata['min']
        if original_dtype == np.int32:
            # Round before casting; a bare float -> int cast would truncate.
            restored = np.round(restored)
        return restored.astype(original_dtype)

    if dtype == np.float16:
        if original_dtype != np.float32:
            # Report the offending target dtype. `data.dtype` is always
            # float16 in this branch, so printing it would hide the cause.
            # The 1-tuple keeps %-formatting safe for any value.
            raise ValueError(
                'Invalid data dtype %r\n'
                'float16 quantization only supports float32 dtype'
                % (original_dtype,))
        return data.astype(original_dtype)

    raise ValueError(
        'Invalid dtype %s for dequantization\n'
        'Supported dtypes are uint8, uint16, float16' % dtype.name)
| 174 | |
| 175 | def _get_affine_quantization_range(min_val, max_val, quantization_dtype): |
| 176 | """Computes quantization range to ensure that zero is represented if covered. |
nothing calls this directly
no test coverage detected
searching dependent graphs…