MCPcopy Index your code
hub / github.com/modelscope/DiffSynth-Studio / ModelDetectorFromSingleFile

Class ModelDetectorFromSingleFile

diffsynth/models/model_manager.py:148–195  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

146
147
class ModelDetectorFromSingleFile:
    """Recognise and load models stored in a single checkpoint file.

    Two registries map a hash of the checkpoint's state dict to a
    ``(model_names, model_classes, model_resource)`` tuple:

    * ``keys_hash_with_shape_dict`` is keyed by
      ``hash_state_dict_keys(state_dict, with_shape=True)`` and is tried first
      (strict matching).
    * ``keys_hash_dict`` is keyed by
      ``hash_state_dict_keys(state_dict, with_shape=False)`` and is used as a
      fallback (loose matching).
    """

    def __init__(self, model_loader_configs=None):
        """Build the registries.

        Args:
            model_loader_configs: Optional iterable of 5-tuples
                ``(keys_hash, keys_hash_with_shape, model_names,
                model_classes, model_resource)``; each one is forwarded to
                :meth:`add_model_metadata`. ``None`` means no entries.
        """
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        if model_loader_configs is not None:
            for metadata in model_loader_configs:
                self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one model description under its hash(es).

        The entry is always stored under ``keys_hash_with_shape``. It is also
        stored under ``keys_hash`` unless ``keys_hash`` is ``None``, in which
        case the model can only be matched strictly.
        """
        metadata = (model_names, model_classes, model_resource)
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = metadata
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = metadata

    def _find_metadata(self, state_dict):
        """Return the registered metadata tuple for ``state_dict``, or ``None``."""
        # Strict matching first: hash computed with ``with_shape=True``.
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]

        # Loose matching: the shape of parameters may be inconsistent, and the
        # state_dict_converter will modify the model architecture.
        # The second hash is only computed when the strict lookup fails.
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return self.keys_hash_dict[keys_hash]
        return None

    def match(self, file_path="", state_dict=None):
        """Tell whether this detector knows how to load the checkpoint.

        Args:
            file_path: Path of the checkpoint. A string naming a directory
                never matches.
            state_dict: Already-loaded state dict. When ``None`` or empty it
                is read from ``file_path`` with ``load_state_dict``.

        Returns:
            ``True`` if the state dict matches a registered entry (strictly or
            loosely), ``False`` otherwise.
        """
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if state_dict is None or len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        return self._find_metadata(state_dict) is not None

    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the model(s) contained in the checkpoint.

        Args:
            file_path: Path of the checkpoint; only read when ``state_dict``
                is ``None`` or empty.
            state_dict: Already-loaded state dict.
            device: Forwarded to ``load_model_from_single_file``.
            torch_dtype: Forwarded to ``load_model_from_single_file``.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The ``(loaded_model_names, loaded_models)`` pair produced by
            ``load_model_from_single_file``.

        Raises:
            ValueError: If the state dict matches no registered entry. The
                previous implementation failed here with an
                ``UnboundLocalError`` because it returned names that were
                never assigned.
        """
        if state_dict is None or len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        metadata = self._find_metadata(state_dict)
        if metadata is None:
            raise ValueError(
                f"Cannot load {file_path!r}: the keys of its state dict do not "
                "match any registered model. Call match() before load()."
            )

        model_names, model_classes, model_resource = metadata
        loaded_model_names, loaded_models = load_model_from_single_file(
            state_dict, model_names, model_classes, model_resource, torch_dtype, device
        )
        return loaded_model_names, loaded_models
196
197
198

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected