| 1480 | # Adapted from `transformers`. |
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
    # Guarded wrapper around `torch.nn.Module.to`.
    #
    # (Written as comments rather than a docstring: `@wraps` copies `__doc__` from
    # `torch.nn.Module.to` onto this function, so a docstring placed here would be replaced.)
    #
    # Checks performed, in order, before delegating to the parent implementation:
    #   1. A quantized model (`self.is_quantized`) called with a dtype raises `ValueError`.
    #   2. A bitsandbytes-quantized model loaded in 8-bit (resp. 4-bit) raises `ValueError` when the
    #      installed bitsandbytes is older than 0.48.0 (resp. 0.43.2), whatever the arguments are.
    #   3. A group-offloaded module called with a device logs a warning and returns `self` untouched.
    # Anything else is forwarded unchanged via `super().to(*args, **kwargs)`.
    from ..hooks.group_offloading import _is_group_offload_enabled

    def _parses_as_device(candidate):
        # True when `torch.device` accepts the string; a RuntimeError is treated as "not a device".
        try:
            torch.device(candidate)
        except RuntimeError:
            return False
        return True

    # Probe every positional string (a list, not a generator, so none is skipped).
    string_probe_results = [_parses_as_device(arg) for arg in args if isinstance(arg, str)]

    # A device counts as requested if it is given as a `torch.device`, as the `device` keyword,
    # or as a positional string that `torch.device` can parse.
    wants_device_move = (
        "device" in kwargs
        or any(isinstance(arg, torch.device) for arg in args)
        or any(string_probe_results)
    )

    # A dtype counts as requested if it is given as the `dtype` keyword or as a positional `torch.dtype`.
    wants_dtype_cast = "dtype" in kwargs or any(isinstance(arg, torch.dtype) for arg in args)

    quantized = getattr(self, "is_quantized", False)

    if quantized and wants_dtype_cast:
        raise ValueError(
            "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
            "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
        )

    if quantized and getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
        # The 8-bit check runs first; the 4-bit check is only reached when the 8-bit one did not raise.
        if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
            raise ValueError(
                "Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
                f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
            )
        if getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
            raise ValueError(
                "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
            )

    # Group-offloaded modules are not moved: warn and hand back the module as-is.
    group_offloaded = _is_group_offload_enabled(self)
    if group_offloaded and wants_device_move:
        logger.warning(
            f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
        )
        return self

    return super().to(*args, **kwargs)
| 1529 | |
| 1530 | # Taken from `transformers`. |
| 1531 | def half(self, *args): |