(model_type, model_path)
| 80 | |
| 81 | |
def load_model(model_type, model_path):
    """Return a cached ``(model_manager, pipe)`` pair, loading it on a cache miss.

    Loaded models are kept in the module-level ``model_dict`` under the key
    ``"<model_type>:<model_path>"``. The cache holds at most
    ``config["max_num_model_cache"]`` entries. When a new model has to be
    loaded and the cache is full, the least recently used entries are moved
    to the CPU and dropped first.

    Args:
        model_type: Key into ``config["model_config"]``. It selects the model
            folder, the file layout to load ("HunyuanDiT", "Kolors", "FLUX",
            or a single-path fallback) and the pipeline class.
        model_path: Path of the model relative to that type's
            ``model_folder``.

    Returns:
        Tuple ``(model_manager, pipe)``.

    Raises:
        KeyError: If ``model_type`` is not present in ``config["model_config"]``.
    """
    global model_dict
    model_key = f"{model_type}:{model_path}"
    if model_key in model_dict:
        # Re-insert so dict order tracks recency: the first key is always the
        # least recently used entry, which is the one evicted first below.
        model_dict[model_key] = model_dict.pop(model_key)
        return model_dict[model_key]

    model_config = config["model_config"][model_type]
    model_path = os.path.join(model_config["model_folder"], model_path)

    # Make room *before* loading, so memory held by evicted models is released
    # before the new weights are allocated. The `model_dict` check stops the
    # loop on an empty cache (e.g. max_num_model_cache <= 0), where
    # next(iter(...)) would otherwise raise StopIteration.
    evicted = False
    while model_dict and len(model_dict) + 1 > config["max_num_model_cache"]:
        oldest_key = next(iter(model_dict))
        model_manager_to_release, _ = model_dict.pop(oldest_key)
        model_manager_to_release.to("cpu")
        evicted = True
    if evicted:
        torch.cuda.empty_cache()

    model_manager = ModelManager()
    if model_type == "HunyuanDiT":
        model_manager.load_models([
            os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
            os.path.join(model_path, "mt5/pytorch_model.bin"),
            os.path.join(model_path, "model/pytorch_model_ema.pt"),
            os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
        ])
    elif model_type == "Kolors":
        model_manager.load_models([
            os.path.join(model_path, "text_encoder"),
            os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
        ])
    elif model_type == "FLUX":
        model_manager.torch_dtype = torch.bfloat16
        file_list = [
            os.path.join(model_path, "text_encoder/model.safetensors"),
            os.path.join(model_path, "text_encoder_2"),
        ]
        # os.listdir order is arbitrary; sort so the load order is reproducible.
        file_list.extend(
            os.path.join(model_path, file_name)
            for file_name in sorted(os.listdir(model_path))
            if file_name.endswith(".safetensors")
        )
        model_manager.load_models(file_list)
    else:
        model_manager.load_model(model_path)

    pipe = model_config["pipeline_class"].from_model_manager(model_manager)
    model_dict[model_key] = model_manager, pipe
    return model_manager, pipe
| 123 | |
| 124 | |
# Module-level model cache used by load_model():
#   "<model_type>:<model_path>" -> (model_manager, pipe)
# load_model() evicts entries from the front of this dict when it is full.
model_dict = dict()
no test coverage detected