MCP · copy · Index your code
hub / github.com/modelscope/DiffSynth-Studio / load_model

Function load_model

apps/streamlit/pages/1_Image_Creator.py:80–111  ·  view source on GitHub ↗
(model_type, model_path)

Source retrieved from the content-addressed store (hash-verified)

78
79
def load_model(model_type, model_path):
    """Load the weights for ``model_type`` from ``model_path`` and build its pipeline.

    The paths handed to ``ModelManager`` depend on ``model_type``:

    * ``"HunyuanDiT"``: the CLIP text encoder, mT5, the EMA DiT weights and
      the SDXL fp16-fix VAE, each at a fixed sub-path of ``model_path``.
    * ``"Kolors"``: the ``text_encoder`` directory plus the UNet and VAE
      ``.safetensors`` files.
    * ``"FLUX"``: the manager's ``torch_dtype`` is set to ``torch.bfloat16``.
      It then loads the two text encoders plus every top-level
      ``*.safetensors`` entry in ``model_path``, in sorted order.
    * anything else: ``model_path`` is passed as-is to
      ``ModelManager.load_model``.

    The pipeline class is read from ``config[model_type]["pipeline_class"]``,
    so ``model_type`` must have an entry in ``config``. NOTE(review): an
    unknown type presumably fails at that lookup; confirm against ``config``.

    Args:
        model_type: Name of the model family. It selects the loading branch
            above and the pipeline class in ``config``.
        model_path: Directory holding the model files (for the fallback
            branch, whatever path ``ModelManager.load_model`` accepts).

    Returns:
        A ``(model_manager, pipeline)`` tuple.

    Side effects:
        Stores ``model_path``, the manager and the pipeline in
        ``st.session_state`` under ``loaded_model_path``, ``model_manager``
        and ``pipeline``.
    """
    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        # Set before any files are loaded so the FLUX weights use bfloat16.
        model_manager.torch_dtype = torch.bfloat16
        file_list = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        # os.listdir returns entries in arbitrary order. Sort them so the list
        # passed to load_models is the same on every machine and every run.
        file_list.extend(
            os.path.join(model_path, file_name)
            for file_name in sorted(os.listdir(model_path))
            if file_name.endswith(".safetensors")
        )
        model_manager.load_models(file_list)
    else:
        model_manager.load_model(model_path)
    pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
    # Cache the loaded state so the Streamlit page can reuse it across reruns.
    st.session_state.loaded_model_path = model_path
    st.session_state.model_manager = model_manager
    st.session_state.pipeline = pipeline
    return model_manager, pipeline
112
113
114def use_output_image_as_input(update=True):

Callers (1)

1_Image_Creator.py — File · 0.70

Calls (4)

load_models — Method · 0.95
load_model — Method · 0.95
ModelManager — Class · 0.90
from_model_manager — Method · 0.45

Tested by

no test coverage detected