MCPcopy Index your code
hub / github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT / get_lora_checkpoints

Function get_lora_checkpoints

ui_trt.py:241–289  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

239
240
def _read_lora_sd_version(config_file, name):
    """Return the ``SDVersion`` recorded in a LoRA's JSON config file.

    Args:
        config_file: Path to the ``<name>.json`` file next to the checkpoint.
        name: LoRA base name, used only in console messages.

    Returns:
        ``SDVersion.from_str(config["sd version"])`` on success, otherwise
        ``SDVersion.Unknown`` (file missing, unreadable, malformed JSON, key
        absent, or value not understood by ``SDVersion.from_str``).
    """
    if not os.path.exists(config_file):
        print(
            "No config file found for {}. You can generate it in the LoRA tab.".format(
                name
            )
        )
        return SDVersion.Unknown

    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        return SDVersion.from_str(config["sd version"])
    except Exception as err:
        # Best effort: a single bad config file must not abort the listing of
        # every other LoRA, so report it and fall back to Unknown.
        print("Could not read SD version for {} from {}: {}".format(name, config_file, err))
        return SDVersion.Unknown


def get_lora_checkpoints():
    """Scan ``cmd_opts.lora_dir`` for LoRA checkpoints and describe each one.

    Files whose extension is ``.pt``, ``.ckpt`` or ``.safetensors`` are
    considered. For each one:

    * ``base_model`` is the ``ss_sd_model_name`` entry of the safetensors
      metadata, or ``"Unknown"`` when the key is absent or the file is not a
      ``.safetensors`` file (a warning is printed for the latter).
    * ``version`` is read from a sibling ``<name>.json`` config file and is
      ``SDVersion.Unknown`` when that file is missing or cannot be parsed.

    Returns:
        dict mapping the checkpoint's base name (filename without extension)
        to ``{"filename": str, "version": SDVersion, "base_model": str}``.
        If two files share a base name (e.g. ``foo.pt`` and
        ``foo.safetensors``), the one listed last by ``os.listdir`` wins.
    """
    allowed_extensions = {".pt", ".ckpt", ".safetensors"}
    available_lora_models = {}

    for filename in os.listdir(cmd_opts.lora_dir):
        # splitext (not split(".")) so an extension-less file literally named
        # e.g. "pt", or a dotfile like ".pt", is not mistaken for a checkpoint.
        name, ext = os.path.splitext(filename)
        if ext not in allowed_extensions:
            continue

        if ext == ".safetensors":
            metadata = sd_models.read_metadata_from_safetensors(
                os.path.join(cmd_opts.lora_dir, filename)
            )
        else:
            metadata = {}
            print(
                "LoRA {} is not a safetensor. This might cause issues when exporting to TensorRT.\n"
                "Please ensure that the correct base model is selected when exporting.".format(
                    name
                )
            )

        base_model = metadata.get("ss_sd_model_name", "Unknown")
        config_file = os.path.join(cmd_opts.lora_dir, name + ".json")
        version = _read_lora_sd_version(config_file, name)

        available_lora_models[name] = {
            "filename": filename,
            "version": version,
            "base_model": base_model,
        }
    return available_lora_models
290
291
292def get_valid_lora_checkpoints():

Callers 2

export_lora_to_trt — Function · 0.85

Calls 2

load — Method · 0.80
from_str — Method · 0.80

Tested by

no test coverage detected