MCP · Copy · Index your code
hub / github.com/modelscope/DiffSynth-Studio / load_model

Function load_model

apps/gradio/DiffSynth_Studio.py:82–122  ·  view source on GitHub ↗
(model_type, model_path)

Source from the content-addressed store, hash-verified

80
81
def load_model(model_type, model_path):
    """Return a cached ``(model_manager, pipe)`` pair, loading it on a miss.

    Entries are kept in the module-level ``model_dict`` under the key
    ``"{model_type}:{model_path}"``. The cache holds at most
    ``config["max_num_model_cache"]`` entries. Before a new model is
    loaded, the least recently used entries are moved to the CPU and
    dropped until there is room for it.

    Args:
        model_type: Key into ``config["model_config"]``. It selects the model
            folder, the file layout that is loaded, and the pipeline class.
            "HunyuanDiT", "Kolors" and "FLUX" have explicit layouts; any
            other type is handed to ``ModelManager.load_model`` as a single path.
        model_path: Model directory name, relative to that model type's
            ``model_folder``.

    Returns:
        A ``(model_manager, pipe)`` tuple, where ``pipe`` is built with
        ``pipeline_class.from_model_manager(model_manager)``.
    """
    global model_dict
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        # Re-insert on a hit so dict order tracks recency of use; the
        # eviction loop below then drops the least recently *used* model
        # instead of the one that was merely loaded first.
        cached = model_dict.pop(model_key)
        model_dict[model_key] = cached
        return cached

    model_config = config["model_config"][model_type]
    model_path = os.path.join(model_config["model_folder"], model_path)

    # Make room *before* loading, so evicted weights are moved to the CPU
    # before the new ones are materialized. This keeps peak memory lower.
    # NOTE(review): assumes ModelManager loads onto the GPU by default.
    # The `model_dict` guard avoids StopIteration on an empty cache when
    # the configured capacity is 0 or negative.
    while model_dict and len(model_dict) >= config["max_num_model_cache"]:
        oldest_key = next(iter(model_dict))
        model_manager_to_release, _ = model_dict[oldest_key]
        model_manager_to_release.to("cpu")
        del model_dict[oldest_key]
        torch.cuda.empty_cache()

    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        model_manager.torch_dtype = torch.bfloat16
        file_list = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        # Every top-level .safetensors file is also loaded. os.listdir
        # order is arbitrary, so sort it for a reproducible load order.
        file_list.extend(
            os.path.join(model_path, file_name)
            for file_name in sorted(os.listdir(model_path))
            if file_name.endswith(".safetensors")
        )
        model_manager.load_models(file_list)
    else:
        model_manager.load_model(model_path)

    pipe = model_config["pipeline_class"].from_model_manager(model_manager)
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe
123
124
# Module-level cache mapping "model_type:model_path" -> (model_manager, pipe).
# Insertion order matters: load_model evicts entries from the front of this dict.
model_dict = dict()

Callers 2

generate_image (Function) · 0.70

Calls 5

load_models (Method) · 0.95
load_model (Method) · 0.95
ModelManager (Class) · 0.90
from_model_manager (Method) · 0.45
to (Method) · 0.45

Tested by

no test coverage detected