| 146 | |
| 147 | |
class ModelDetectorFromSingleFile:
    """Recognise and load models stored in a single state-dict file.

    A state dict is recognised by hashing its keys with
    ``hash_state_dict_keys``. Two lookup tables map such hashes to the
    registered ``(model_names, model_classes, model_resource)`` metadata:

    * ``keys_hash_with_shape_dict`` -- hashes computed with
      ``with_shape=True`` (strict matching).
    * ``keys_hash_dict`` -- hashes computed with ``with_shape=False``
      (loose matching), consulted only when the strict lookup misses.
    """

    def __init__(self, model_loader_configs=None):
        """Build the lookup tables.

        Args:
            model_loader_configs: Iterable of 5-tuples
                ``(keys_hash, keys_hash_with_shape, model_names,
                model_classes, model_resource)``, each forwarded to
                :meth:`add_model_metadata`. ``None`` (the default)
                registers nothing.
        """
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        # ``None`` rather than ``[]`` avoids a shared mutable default argument.
        if model_loader_configs is None:
            model_loader_configs = []
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one model configuration.

        The metadata is always stored under ``keys_hash_with_shape``; it is
        additionally stored under ``keys_hash`` unless that is ``None``
        (i.e. the model opts out of loose matching).
        """
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)


    def _find_metadata(self, state_dict):
        """Return the registered ``(model_names, model_classes, model_resource)``
        for ``state_dict``, or ``None`` if it is not recognised.

        The strict hash is tried first; the key-only hash is computed only
        if the strict lookup misses.
        """
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]
        # Loose matching: the shape of parameters may be inconsistent, and the
        # state_dict_converter will modify the model architecture.
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        return self.keys_hash_dict.get(keys_hash)


    def match(self, file_path="", state_dict=None):
        """Return True if the state dict matches a registered model.

        Args:
            file_path: Path to the checkpoint. Directories never match.
            state_dict: Already-loaded state dict. When ``None`` or empty,
                it is loaded from ``file_path`` via ``load_state_dict``.
        """
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if state_dict is None or len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        return self._find_metadata(state_dict) is not None


    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the model(s) described by a single state-dict file.

        Args:
            file_path: Path to the checkpoint, used only when ``state_dict``
                is ``None`` or empty.
            state_dict: Already-loaded state dict.
            device: Forwarded to ``load_model_from_single_file``.
            torch_dtype: Forwarded to ``load_model_from_single_file``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``(loaded_model_names, loaded_models)`` as produced by
            ``load_model_from_single_file``, or ``([], [])`` when the state
            dict matches no registered model.
        """
        if state_dict is None or len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        metadata = self._find_metadata(state_dict)
        if metadata is None:
            # Previously this fell through to a return of never-assigned names
            # (UnboundLocalError); report "nothing loaded" explicitly instead.
            return [], []

        model_names, model_classes, model_resource = metadata
        loaded_model_names, loaded_models = load_model_from_single_file(
            state_dict, model_names, model_classes, model_resource, torch_dtype, device
        )
        return loaded_model_names, loaded_models
| 196 | |
| 197 | |
| 198 | |