(model_type, model_path)
| 78 | |
| 79 | |
def load_model(model_type, model_path):
    """Load the weights for ``model_type`` from ``model_path`` and build its pipeline.

    The set of paths handed to ``ModelManager`` depends on ``model_type``:

    * ``"HunyuanDiT"`` -- CLIP text encoder, mT5, the EMA model checkpoint and
      the ``sdxl-vae-fp16-fix`` VAE, each at a fixed sub-path.
    * ``"Kolors"`` -- the ``text_encoder`` directory plus the UNet and VAE
      ``.safetensors`` files.
    * ``"FLUX"`` -- sets the manager's ``torch_dtype`` to ``torch.bfloat16``,
      then loads ``text_encoder/model.safetensors``, the ``text_encoder_2``
      directory, and every ``*.safetensors`` entry directly inside
      ``model_path`` (in sorted order).
    * anything else -- ``model_path`` is passed unchanged to
      ``ModelManager.load_model``.

    The pipeline is constructed via
    ``config[model_type]["pipeline_class"].from_model_manager`` and the path,
    manager and pipeline are stored on ``st.session_state`` (presumably
    Streamlit's session state -- confirm against the file's imports).

    Args:
        model_type: Key into the module-level ``config`` mapping; also selects
            which files are loaded.
        model_path: Base path of the model files.

    Returns:
        A ``(model_manager, pipeline)`` tuple.

    Raises:
        KeyError: If ``model_type`` has no entry in ``config``. This is checked
            before any weights are loaded.
    """
    # Resolve the pipeline class up front so an unknown model_type fails fast
    # instead of after the (potentially expensive) weight loading below.
    pipeline_class = config[model_type]["pipeline_class"]

    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        model_manager.torch_dtype = torch.bfloat16
        file_list = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        # os.listdir returns entries in arbitrary order; sort so the load
        # order is deterministic across runs and filesystems.
        file_list.extend(
            os.path.join(model_path, file_name)
            for file_name in sorted(os.listdir(model_path))
            if file_name.endswith(".safetensors")
        )
        model_manager.load_models(file_list)
    else:
        model_manager.load_model(model_path)

    pipeline = pipeline_class.from_model_manager(model_manager)
    # Cache the loaded state so later reruns can reuse it.
    st.session_state.loaded_model_path = model_path
    st.session_state.model_manager = model_manager
    st.session_state.pipeline = pipeline
    return model_manager, pipeline
| 112 | |
| 113 | |
| 114 | def use_output_image_as_input(update=True): |
no test coverage detected