(tensor, dtype=None, dequant_dtype=None)
| 13 | return not is_torch_compatible(tensor) |
| 14 | |
def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
    """Return ``tensor`` as a plain torch tensor cast to ``dtype``.

    The quantization type is read from ``tensor.tensor_type``; it is ``None``
    when that attribute is absent. Three paths are tried in order:

    1. The type is in ``TORCH_COMPATIBLE_QTYPES``: the tensor is only cast.
    2. The type has no entry in ``dequantize_functions``: the data is copied to
       a CPU numpy array and decoded with ``gguf.quants.dequantize``. The
       original author notes this is very slow. The result is moved back to
       ``tensor.device`` as ``dtype``.
    3. Otherwise the data is decoded with ``dequantize``, using the logical
       shape from ``tensor.tensor_shape`` (``tensor.shape`` when that attribute
       is absent). The result is then cast to ``dtype``.

    Args:
        tensor: torch tensor, optionally carrying ``tensor_type`` and
            ``tensor_shape`` attributes.
        dtype: dtype of the returned tensor.
        dequant_dtype: dtype forwarded to ``dequantize`` on path 3, or the
            string ``"target"`` to reuse ``dtype`` for that step.
    """
    qtype = getattr(tensor, "tensor_type", None)
    logical_shape = getattr(tensor, "tensor_shape", tensor.shape)

    # Types listed as torch-compatible need no decoding, only a cast.
    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)

    # No registered decoder: round-trip through numpy on the CPU.
    if qtype not in dequantize_functions:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
        host_values = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(host_values).to(tensor.device, dtype=dtype)

    # "target" means decode directly in the requested output dtype.
    if dequant_dtype == "target":
        dequant_dtype = dtype
    restored = dequantize(tensor.data, qtype, logical_shape, dtype=dequant_dtype)
    return restored.to(dtype)
| 29 | |
| 30 | def dequantize(data, qtype, oshape, dtype=None): |
| 31 | """ |
no test coverage detected