Shard the model and/or load its weights. Args: model_or_path (Union[nn.Module, str]): path to a checkpoint, or a model in Transformers format. model_policy (Policy): the policy used to replace the model. model_inference_config: the configuration for model initialization at inference time.
(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
)
| 105 | self._verify_args() |
| 106 | |
| 107 | def init_model( |
| 108 | self, |
| 109 | model_or_path: Union[nn.Module, str], |
| 110 | model_policy: Union[Policy, Type[Policy]] = None, |
| 111 | model_shard_infer_config: ModelShardInferenceConfig = None, |
| 112 | ): |
| 113 | """ |
| 114 | Shard model or/and Load weight |
| 115 | |
| 116 | Args: |
| 117 | model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. |
| 118 | model_policy (Policy): the policy to replace the model. |
| 119 | model_inference_config: the configuration for modeling initialization when inference. |
| 120 | model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. |
| 121 | """ |
| 122 | pretrained_path = None |
| 123 | if isinstance(model_or_path, str): |
| 124 | import colossalai.interface.pretrained as pretrained_utils |
| 125 | |
| 126 | try: |
| 127 | hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) |
| 128 | arch = getattr(hf_config, "architectures")[0] |
| 129 | if arch in _supported_models.keys(): |
| 130 | if arch == "BaichuanForCausalLM": |
| 131 | self.logger.warning( |
| 132 | "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers" |
| 133 | ) |
| 134 | ctx = LazyInitContext(default_device="cuda") |
| 135 | with ctx: |
| 136 | model = _supported_models[arch].from_pretrained( |
| 137 | model_or_path, trust_remote_code=True, torch_dtype=self.dtype |
| 138 | ) |
| 139 | pretrained_path = pretrained_utils.get_pretrained_path(model) |
| 140 | else: |
| 141 | # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate |
| 142 | raise ValueError(f"Model {arch} is not supported.") |
| 143 | |
| 144 | except Exception as e: |
| 145 | self.logger.error( |
| 146 | f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" |
| 147 | ) |
| 148 | else: |
| 149 | model = model_or_path |
| 150 | |
| 151 | self.model_config = model.config |
| 152 | |
| 153 | torch.cuda.empty_cache() |
| 154 | init_gpu_memory = torch.cuda.mem_get_info()[0] |
| 155 | |
| 156 | self.device = get_accelerator().get_current_device() |
| 157 | if self.verbose: |
| 158 | self.logger.info(f"the device is {self.device}") |
| 159 | |
| 160 | model = model.to(self.dtype).eval() |
| 161 | |
| 162 | if self.verbose: |
| 163 | self.logger.info( |
| 164 | f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" |
no test coverage detected