hub / github.com/hpcaitech/ColossalAI / load_model

Method load_model

colossalai/booster/booster.py:291–313

Load model from checkpoint. Args: model (nn.Module or ModelWrapper): A model boosted by Booster. checkpoint (str): Path to the checkpoint. It must be a local path. It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
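The path rule in the summary (a directory for a sharded checkpoint, a single file otherwise) can be checked before calling load_model. The following is a minimal sketch under that assumption. `checkpoint_kind` is a hypothetical helper, not part of ColossalAI, and it only inspects the local filesystem.

```python
from pathlib import Path


def checkpoint_kind(checkpoint: str) -> str:
    """Classify a local checkpoint path (hypothetical helper, not ColossalAI API).

    Returns "sharded" for a directory and "single-file" for a regular file,
    mirroring the rule stated in the load_model docstring.
    """
    path = Path(checkpoint)
    if path.is_dir():
        return "sharded"
    if path.is_file():
        return "single-file"
    # The docstring requires a local path, so anything else is rejected here.
    raise FileNotFoundError(f"checkpoint must be an existing local path: {checkpoint}")
```

A caller could run this check first and fail early with a clear message, rather than discovering a wrong path inside the checkpoint loader.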

(
        self,
        model: Union[nn.Module, ModelWrapper],
        checkpoint: str,
        strict: bool = True,
        low_cpu_mem_mode: bool = True,
        num_threads: int = 1,
    )
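The `strict` parameter is documented in terms of key matching against the module's state_dict. The sketch below illustrates that idea in plain Python. `check_keys` is a hypothetical function written for this page, not ColossalAI or PyTorch code, and the exact error type and message of the real loader may differ.

```python
from typing import Dict, List, Tuple


def check_keys(
    expected: List[str], loaded: Dict[str, object], strict: bool = True
) -> Tuple[List[str], List[str]]:
    """Compare checkpoint keys with the keys a module expects (illustration only)."""
    # Keys the module expects but the checkpoint does not provide.
    missing = [k for k in expected if k not in loaded]
    # Keys the checkpoint provides but the module does not know.
    unexpected = [k for k in loaded if k not in expected]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}; unexpected keys: {unexpected}")
    return missing, unexpected
```

With `strict=False` the mismatch is reported instead of raised, which is the usual choice when loading a partial checkpoint on purpose.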

Source from the content-addressed store, hash-verified

289         return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
290
291     def load_model(
292         self,
293         model: Union[nn.Module, ModelWrapper],
294         checkpoint: str,
295         strict: bool = True,
296         low_cpu_mem_mode: bool = True,
297         num_threads: int = 1,
298     ) -> None:
299         """Load model from checkpoint.
300
301         Args:
302             model (nn.Module or ModelWrapper): A model boosted by Booster.
303             checkpoint (str): Path to the checkpoint. It must be a local path.
304                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
305             strict (bool, optional): whether to strictly enforce that the keys
306                 in :attr:`state_dict` match the keys returned by this module's
307                 :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
308             low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
309             num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
310         """
311         self.checkpoint_io.load_model(
312             model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
313         )
314
315     def save_model(
316         self,
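The method body is a pure delegation: it forwards its arguments to `self.checkpoint_io.load_model`, passing `strict` positionally and the two loading options by keyword. The following sketch reproduces only that forwarding shape with stand-in classes. `RecordingCheckpointIO` and `MiniBooster` are hypothetical names, and the stand-in's signature is an assumption made so the example runs without ColossalAI installed.

```python
class RecordingCheckpointIO:
    """Stand-in for the checkpoint IO object; it only records the call (hypothetical)."""

    def __init__(self):
        self.calls = []

    def load_model(self, model, checkpoint, strict=True, low_cpu_mem_mode=True, num_threads=1):
        self.calls.append((model, checkpoint, strict, low_cpu_mem_mode, num_threads))


class MiniBooster:
    """Toy class mirroring only the forwarding shape of Booster.load_model."""

    def __init__(self, checkpoint_io):
        self.checkpoint_io = checkpoint_io

    def load_model(self, model, checkpoint, strict=True, low_cpu_mem_mode=True, num_threads=1):
        # strict goes positionally; the other two options go by keyword, as in the listing.
        self.checkpoint_io.load_model(
            model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
        )


io = RecordingCheckpointIO()
# Disable low CPU memory mode so that num_threads is meaningful, per the docstring.
MiniBooster(io).load_model("model", "ckpt_dir", low_cpu_mem_mode=False, num_threads=4)
print(io.calls)  # [('model', 'ckpt_dir', True, False, 4)]
```

Per the docstring, `num_threads` only matters when `low_cpu_mem_mode` is false, so the two options are set together in the call above.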

Callers 15

train · Function · 0.95
train · Function · 0.95
train · Function · 0.95
train · Function · 0.95
train · Function · 0.95
train · Function · 0.95
train · Function · 0.95
train · Function · 0.95
benchmark_train · Function · 0.95
train · Function · 0.95
boost · Method · 0.95

Calls

no outgoing calls indexed (the method body delegates to self.checkpoint_io.load_model, lines 311–313)