hub / github.com/hpcaitech/ColossalAI / init_model

Method init_model

colossalai/inference/core/llm_engine.py:107–210

Shard the model and/or load its weights. Accepts either a checkpoint path or an nn.Module in transformers format (model_or_path), an optional policy used to replace the model (model_policy), and an optional ModelShardInferenceConfig that configures module initialization for inference (model_shard_infer_config).

(
        self,
        model_or_path: Union[nn.Module, str],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    )

Source from the content-addressed store, hash-verified

        self._verify_args()

    def init_model(
        self,
        model_or_path: Union[nn.Module, str],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard model or/and Load weight

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
            model_policy (Policy): the policy to replace the model.
            model_inference_config: the configuration for modeling initialization when inference.
            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
        """
        pretrained_path = None
        if isinstance(model_or_path, str):
            import colossalai.interface.pretrained as pretrained_utils

            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                arch = getattr(hf_config, "architectures")[0]
                if arch in _supported_models.keys():
                    if arch == "BaichuanForCausalLM":
                        self.logger.warning(
                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
                        )
                    ctx = LazyInitContext(default_device="cuda")
                    with ctx:
                        model = _supported_models[arch].from_pretrained(
                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
                        )
                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                else:
                    # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                    raise ValueError(f"Model {arch} is not supported.")

            except Exception as e:
                self.logger.error(
                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
                )
        else:
            model = model_or_path

        self.model_config = model.config

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        model = model.to(self.dtype).eval()

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
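
The architecture lookup in the listing (read the first entry of the config's architectures, check it against a registry, raise ValueError if it is missing) can be illustrated in isolation. The following is a minimal sketch, not ColossalAI code: SUPPORTED_MODELS, FakeConfig, DummyLlama, and resolve_model_class are hypothetical names, and the dictionary only stands in for _supported_models. As far as the visible lines show, the listing logs the exception and continues, which would leave model unassigned when loading fails; the sketch instead lets the ValueError propagate to the caller.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Type


class DummyLlama:
    """Hypothetical stand-in for a supported model class."""


# Hypothetical registry; plays the role of `_supported_models` in the listing.
SUPPORTED_MODELS: Dict[str, Type] = {"LlamaForCausalLM": DummyLlama}


@dataclass
class FakeConfig:
    """Hypothetical stand-in for a Hugging Face config object."""

    architectures: List[str] = field(default_factory=list)


def resolve_model_class(config: FakeConfig) -> Type:
    """Return the registered class for the first architecture in `config`.

    Raises ValueError when the config names no architecture or names one
    that is not in the registry, so the caller never continues without a model.
    """
    archs = getattr(config, "architectures", None) or []
    if not archs:
        raise ValueError("Config does not declare any architecture.")
    arch = archs[0]
    if arch not in SUPPORTED_MODELS:
        raise ValueError(f"Model {arch} is not supported.")
    return SUPPORTED_MODELS[arch]
```

With this shape, resolve_model_class(FakeConfig(["LlamaForCausalLM"])) returns the registered class, while an unregistered architecture name fails immediately with a ValueError instead of surfacing later as an unbound variable.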

Callers 1

__init__ (Method) · 0.95

Calls 15

get_group_along_axis (Method) · 0.95
LazyInitContext (Class) · 0.90
get_accelerator (Function) · 0.90
get_model_size (Function) · 0.90
ProcessGroupMesh (Class) · 0.90
ModelWrapper (Class) · 0.90
InferCheckpoint_io (Class) · 0.90
has_index_file (Function) · 0.90
get_rank (Method) · 0.80
from_pretrained (Method) · 0.45
warning (Method) · 0.45
error (Method) · 0.45

Tested by

no test coverage detected