hub / github.com/hpcaitech/ColossalAI / InferCheckpoint_io

Class InferCheckpoint_io

colossalai/inference/core/plugin.py:21–140

This class loads models for inference. Most of its code is copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO. The original HybridParallelCheckpointIO contains code for mixed-precision training, which is removed here to leave a relatively clean class specifically for inference.

Source from the content-addressed store, hash-verified

class InferCheckpoint_io(GeneralCheckpointIO):
    """
    Checkpoint IO for loading models for inference. Most of the code is copied from
    colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO.
    The original HybridParallelCheckpointIO contains code for mixed-precision training,
    which is removed here to leave a relatively clean class specifically for inference.
    """

    def __init__(
        self,
        verbose: bool = True,
    ) -> None:
        super().__init__()
        self.verbose = verbose
        self.coordinator = DistCoordinator()

    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
        """
        Load a sharded model from the index file of a checkpoint folder.

        Args:
            model (ModelWrapper): The boosted (wrapped) model to load weights into.
            checkpoint_index_file (Path): Path to the index file of the checkpoint folder.
            strict (bool, optional): Whether to enforce name matching when loading the state_dict. Defaults to False.
                Leave this as False, since params on the same device might be stored in different files.
        """
        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
        model = model.unwrap()

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        # Loading is always non-strict here; any value passed by the caller is overridden.
        strict = False

        # Load params & buffers to model.
        # Keep a record of loaded files so that no file is loaded more than once.
        loaded_file = set()

        missing_keys = []
        missing_file_keys = []

        def _load(name: str):
            # Names absent from the index are recorded instead of raising.
            if name not in weight_map:
                missing_file_keys.append(name)
                return
            filename = weight_map[name]

            # If the file holding this param/buffer has been loaded before, return directly.
            if filename in loaded_file:
                return
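The excerpt stops before the shard is actually read, so the following is only a minimal sketch of the bookkeeping visible above: the name-based safetensors check, the `weight_map` lookup, and the `loaded_file` set that keeps each shard file from being read twice. It does not use ColossalAI. The helper names `uses_safetensors` and `plan_shard_loads`, the parameter names, and the index contents are hypothetical, and the index layout (a JSON object with a `weight_map` key) is assumed to follow the common sharded-checkpoint convention.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Set, Tuple


def uses_safetensors(index_file: Path) -> bool:
    """Name-based check, as in the excerpt: look for 'safetensors' in the file name."""
    return "safetensors" in index_file.name


def plan_shard_loads(
    param_names: Iterable[str],
    weight_map: Dict[str, str],
) -> Tuple[List[str], List[str]]:
    """Return (shard files to read, names absent from the index).

    Each shard file is listed at most once, in first-seen order.
    """
    loaded_file: Set[str] = set()          # files already scheduled for reading
    files_to_read: List[str] = []
    missing_file_keys: List[str] = []      # names the index does not know about
    for name in param_names:
        if name not in weight_map:
            missing_file_keys.append(name)
            continue
        filename = weight_map[name]
        if filename in loaded_file:
            continue                       # shard already covered by an earlier name
        loaded_file.add(filename)
        files_to_read.append(filename)
    return files_to_read, missing_file_keys


# Hypothetical index content; real checkpoints list every parameter and buffer.
index = {
    "metadata": {},
    "weight_map": {
        "embed.weight": "model-00001-of-00002.safetensors",
        "layer0.weight": "model-00001-of-00002.safetensors",
        "head.weight": "model-00002-of-00002.safetensors",
    },
}

with tempfile.TemporaryDirectory() as tmp:
    index_path = Path(tmp) / "model.safetensors.index.json"
    index_path.write_text(json.dumps(index))
    weight_map = json.loads(index_path.read_text())["weight_map"]
    files, missing = plan_shard_loads(
        ["embed.weight", "layer0.weight", "head.weight", "rotary.inv_freq"],
        weight_map,
    )
    is_safe = uses_safetensors(index_path)
```

Here `files` lists two shard files even though three of the requested names are present in the index, and the one name missing from the index ends up in `missing` rather than raising, which matches the non-strict behavior the docstring describes.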

Callers 2

init_model · Method · 0.90
_init_model · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected
