()
| 239 | |
| 240 | |
def get_lora_checkpoints():
    """Scan ``cmd_opts.lora_dir`` for LoRA checkpoints and describe each one.

    A directory entry is a candidate when it is a regular file whose extension
    is ``.pt``, ``.ckpt`` or ``.safetensors``. For ``.safetensors`` files the
    embedded metadata is read with ``sd_models.read_metadata_from_safetensors``
    to obtain the base model name. Other formats only get a warning, because
    that metadata is not read for them. The SD version comes from the
    ``"sd version"`` key of a sidecar ``<name>.json`` file, when one exists.

    Returns:
        dict mapping each checkpoint's file stem to a dict with keys
        ``"filename"`` (the file name inside ``lora_dir``), ``"version"``
        (an ``SDVersion``; ``SDVersion.Unknown`` when it cannot be determined)
        and ``"base_model"`` (``"Unknown"`` when the metadata has no
        ``"ss_sd_model_name"`` entry).
    """
    available_lora_models = {}
    allowed_extensions = {".pt", ".ckpt", ".safetensors"}
    lora_dir = cmd_opts.lora_dir

    for filename in os.listdir(lora_dir):
        name, ext = os.path.splitext(filename)
        path = os.path.join(lora_dir, filename)
        # Use splitext (not split(".")) so a dot-less name such as "pt" or a
        # dotfile such as ".pt" is not mistaken for a checkpoint; this also
        # keeps the filter consistent with the `ext` checks below. Skip
        # directories that merely carry a checkpoint-like suffix.
        if ext not in allowed_extensions or not os.path.isfile(path):
            continue

        metadata = {}
        if ext == ".safetensors":
            metadata = sd_models.read_metadata_from_safetensors(path)
        else:
            print(
                "LoRA {} is not a safetensor. This might cause issues when "
                "exporting to TensorRT.\n"
                "Please ensure that the correct base model is selected when "
                "exporting.".format(name)
            )

        base_model = metadata.get("ss_sd_model_name", "Unknown")

        config_file = os.path.join(lora_dir, name + ".json")
        version = SDVersion.Unknown
        if os.path.exists(config_file):
            try:
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
                version = SDVersion.from_str(config["sd version"])
            except Exception as err:
                # A malformed or incomplete sidecar must not abort the whole
                # scan. Exception (not a bare except) lets KeyboardInterrupt /
                # SystemExit propagate. SDVersion.from_str's failure modes are
                # not visible here, so the catch is kept broad.
                print(
                    "Could not read SD version from {}: {}".format(config_file, err)
                )
                version = SDVersion.Unknown
        else:
            print(
                "No config file found for {}. You can generate it in the LoRA tab.".format(
                    name
                )
            )

        available_lora_models[name] = {
            "filename": filename,
            "version": version,
            "base_model": base_model,
        }
    return available_lora_models
| 290 | |
| 291 | |
| 292 | def get_valid_lora_checkpoints(): |
no test coverage detected