This class is for inference model loading; most of its code is copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO. The original HybridParallelCheckpointIO contains some code for mixed-precision training, so we removed it and built a relatively clean class specifically for inference.
| 19 | |
| 20 | |
| 21 | class InferCheckpoint_io(GeneralCheckpointIO): |
| 22 | """ |
| 23 | This class is for inference model loading, most codes are copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO. |
| 24 | Origin HybridParallelCheckpointIO contains some codes about MixPrecision-Training, so we remove them and build a relatively clean class specifically for Inference. |
| 25 | """ |
| 26 | |
def __init__(self, verbose: bool = True) -> None:
    """Initialise the inference checkpoint loader.

    Args:
        verbose (bool, optional): Stored on the instance as ``self.verbose``.
            Defaults to True.
    """
    super().__init__()
    # Keep the caller's verbosity flag, then attach a distributed coordinator
    # for use by the loading methods.
    self.verbose = verbose
    self.coordinator = DistCoordinator()
| 34 | |
| 35 | def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): |
| 36 | """ |
| 37 | Load sharded model with the given path to index file of checkpoint folder. |
| 38 | |
| 39 | Args: |
| 40 | model (nn.Module): The model to be loaded. |
| 41 | checkpoint_index_file (str): Path to the index file of checkpointing folder. |
| 42 | strict (bool, optional): For name matching during loading state_dict. Defaults to False. |
| 43 | This argument should be manually set to False since params on same device might be stored in different files. |
| 44 | """ |
| 45 | assert isinstance(model, ModelWrapper), "Please boost the model before loading!" |
| 46 | model = model.unwrap() |
| 47 | |
| 48 | # Check whether the checkpoint uses safetensors. |
| 49 | use_safetensors = False |
| 50 | if "safetensors" in checkpoint_index_file.name: |
| 51 | use_safetensors = True |
| 52 | |
| 53 | if use_safetensors and not is_safetensors_available(): |
| 54 | raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") |
| 55 | |
| 56 | # Read checkpoint index file. |
| 57 | ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) |
| 58 | ckpt_root_path = ckpt_index_file.root_path |
| 59 | weight_map = ckpt_index_file.weight_map |
| 60 | strict = False |
| 61 | |
| 62 | # Load params & buffers to model. |
| 63 | # Keep a record of loaded files so that file will not be repeatedly loaded. |
| 64 | loaded_file = set() |
| 65 | |
| 66 | missing_keys = [] |
| 67 | missing_file_keys = [] |
| 68 | |
| 69 | def _load(name: str): |
| 70 | if name not in weight_map: |
| 71 | missing_file_keys.append(name) |
| 72 | return |
| 73 | filename = weight_map[name] |
| 74 | |
| 75 | # If this param/buffer has been loaded before, directly return. |
| 76 | if filename in loaded_file: |
| 77 | return |
| 78 |
no outgoing calls
no test coverage detected
searching dependent graphs…