Load an un-quantized pretrained model onto the CPU, or dispatch it across devices according to `max_memory` when that is given.
(
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
max_memory: Optional[dict] = None,
trust_remote_code: bool = False,
torch_dtype: torch.dtype = torch.float16,
**model_init_kwargs
)
| 29 | |
| 30 | @classmethod |
| 31 | def from_pretrained( |
| 32 | cls, |
| 33 | pretrained_model_name_or_path: str, |
| 34 | quantize_config: BaseQuantizeConfig, |
| 35 | max_memory: Optional[dict] = None, |
| 36 | trust_remote_code: bool = False, |
| 37 | torch_dtype: torch.dtype = torch.float16, |
| 38 | **model_init_kwargs |
| 39 | ): |
| 40 | """load un-quantized pretrained model to cpu""" |
| 41 | |
| 42 | if not torch.cuda.is_available(): |
| 43 | raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") |
| 44 | |
| 45 | def skip(*args, **kwargs): |
| 46 | pass |
| 47 | |
| 48 | torch.nn.init.kaiming_uniform_ = skip |
| 49 | torch.nn.init.uniform_ = skip |
| 50 | torch.nn.init.normal_ = skip |
| 51 | |
| 52 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
| 53 | |
| 54 | # enforce some values despite user specified |
| 55 | model_init_kwargs["torch_dtype"] = torch_dtype |
| 56 | model_init_kwargs["trust_remote_code"] = trust_remote_code |
| 57 | if max_memory: |
| 58 | if "disk" in max_memory: |
| 59 | raise NotImplementedError("disk offload not support yet.") |
| 60 | with accelerate.init_empty_weights(): |
| 61 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) |
| 62 | model.tie_weights() |
| 63 | |
| 64 | max_memory = accelerate.utils.get_balanced_memory( |
| 65 | model, |
| 66 | max_memory=max_memory, |
| 67 | no_split_module_classes=[cls.layer_type], |
| 68 | dtype=model_init_kwargs["torch_dtype"], |
| 69 | low_zero=False |
| 70 | ) |
| 71 | model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( |
| 72 | model, |
| 73 | max_memory=max_memory, |
| 74 | no_split_module_classes=[cls.layer_type], |
| 75 | dtype=model_init_kwargs["torch_dtype"] |
| 76 | ) |
| 77 | model_init_kwargs["low_cpu_mem_usage"] = True |
| 78 | |
| 79 | del model |
| 80 | else: |
| 81 | model_init_kwargs["device_map"] = None |
| 82 | model_init_kwargs["low_cpu_mem_usage"] = False |
| 83 | |
| 84 | torch.cuda.empty_cache() |
| 85 | |
| 86 | model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs) |
| 87 | model_config = model.config.to_dict() |
| 88 | seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] |
no test coverage detected