| 532 | |
| 533 | |
class GGUFParameter(torch.nn.Parameter):
    """Parameter subclass that carries GGUF quantization metadata.

    Each instance stores two extra attributes next to the tensor data:

    * ``quant_type`` -- the key used to look up ``GGML_QUANT_SIZES``.
    * ``quant_shape`` -- the value returned by ``_quant_shape_from_byte_shape``
      for the tensor's own shape and that entry's ``(block_size, type_size)``
      (presumably the logical, de-quantized shape -- confirm against the
      helper).

    ``__torch_function__`` is overridden so tensors produced by torch
    operations on a ``GGUFParameter`` are re-wrapped with the same
    ``quant_type``.
    """

    def __new__(cls, data, requires_grad=False, quant_type=None):
        """Build the parameter and attach its quantization metadata.

        Args:
            data: Tensor holding the parameter data; ``None`` is replaced by
                an empty tensor.
            requires_grad: Forwarded to ``torch.Tensor._make_subclass``.
            quant_type: Key into ``GGML_QUANT_SIZES``. The lookup is
                unconditional, so a value missing from that table raises
                ``KeyError``.
        """
        data = data if data is not None else torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_type = quant_type
        block_size, type_size = GGML_QUANT_SIZES[quant_type]
        self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)

        return self

    def as_tensor(self):
        """Return this parameter re-wrapped as a plain ``torch.Tensor``."""
        return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)

    @staticmethod
    def _extract_quant_type(args):
        """Return the ``quant_type`` of the first ``GGUFParameter`` in ``args``.

        When converting from original format checkpoints we often use splits,
        cats etc. on tensors. This lookup lets the tensors returned from those
        operations remain ``GGUFParameter`` so the quant_type information is
        preserved.

        Each positional argument is checked directly; list and tuple
        arguments (e.g. the sequence handed to ``torch.cat``) are searched
        element by element, so empty sequences and sequences whose first
        element is a plain tensor are handled.

        Returns:
            The quant type that was found, or ``None`` when no argument is a
            ``GGUFParameter``.
        """
        for arg in args:
            if isinstance(arg, GGUFParameter):
                return arg.quant_type
            if isinstance(arg, (list, tuple)):
                for item in arg:
                    if isinstance(item, GGUFParameter):
                        return item.quant_type
        return None

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` via the parent implementation and re-wrap its output.

        A tensor result, and every tensor inside a list/tuple result, is
        wrapped in ``cls`` with the quant type taken from the positional
        arguments. Any other result is returned untouched.
        """
        if kwargs is None:
            kwargs = {}

        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, torch.Tensor):
            quant_type = cls._extract_quant_type(args)
            return cls(result, quant_type=quant_type)

        # Exact type check on purpose: only plain lists and tuples are
        # rebuilt; subclasses of them fall through and are returned as-is.
        if type(result) in (list, tuple):
            quant_type = cls._extract_quant_type(args)
            wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
            # Preserve the original container type (tuple or list).
            return type(result)(wrapped)

        return result
| 577 | |
| 578 | |
| 579 | class GGUFLinear(nn.Linear): |
no outgoing calls
searching dependent graphs…