(self, *args, **kwargs)
| 1453 | # Adapted from `transformers`. |
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
    # Override of `torch.nn.Module.cuda` that adds two pre-flight checks:
    #   1. bitsandbytes-quantized models may only be moved when the installed
    #      bitsandbytes version passes the gate below (otherwise ValueError);
    #   2. group-offloaded modules are not moved: a warning is logged and
    #      `self` is returned unchanged.
    # In every other case the call is forwarded to the parent implementation
    # with the caller's positional and keyword arguments.

    # Imported inside the method rather than at module level
    # (NOTE(review): presumably to avoid a circular import -- confirm).
    from ..hooks.group_offloading import _is_group_offload_enabled

    quantized_with_bnb = getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
    if quantized_with_bnb:
        # The 8-bit gate is evaluated first; the 4-bit gate is only consulted
        # when the 8-bit one does not trigger.
        if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
            message = (
                "Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
                f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
            )
            raise ValueError(message)
        elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
            message = (
                "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
            )
            raise ValueError(message)

    # Common path: no group offloading, so delegate to the parent class.
    if not _is_group_offload_enabled(self):
        return super().cuda(*args, **kwargs)

    # Group-offloaded modules are left where they are; warn and hand back `self`.
    logger.warning(
        f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
    )
    return self
| 1479 | |
| 1480 | # Adapted from `transformers`. |
| 1481 | @wraps(torch.nn.Module.to) |
no test coverage detected