(self, prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu")
| 475 | DESCRIPTION = "Encodes text prompt into text embedding." |
| 476 | |
| 477 | def process(self, prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): |
| 478 | # Unified cache logic: use a single cache file per unique prompt |
| 479 | encoded = None |
| 480 | echoshot = True if "[1]" in prompt else False |
| 481 | if use_disk_cache: |
| 482 | cache_dir = os.path.join(script_directory, 'text_embed_cache') |
| 483 | os.makedirs(cache_dir, exist_ok=True) |
| 484 | def get_cache_path(prompt): |
| 485 | cache_key = prompt.strip() |
| 486 | cache_hash = hashlib.sha256(cache_key.encode('utf-8')).hexdigest() |
| 487 | return os.path.join(cache_dir, f"{cache_hash}.pt") |
| 488 | cache_path = get_cache_path(prompt) |
| 489 | if os.path.exists(cache_path): |
| 490 | try: |
| 491 | log.info(f"Loading prompt embeds from cache: {cache_path}") |
| 492 | encoded = torch.load(cache_path) |
| 493 | except Exception as e: |
| 494 | log.warning(f"Failed to load cache: {e}, will re-encode.") |
| 495 | |
| 496 | if t5 is None and encoded is None: |
| 497 | raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") |
| 498 | |
| 499 | if encoded is None: |
| 500 | try: |
| 501 | if model_to_offload is not None and device == "gpu": |
| 502 | log.info(f"Moving video model to {offload_device}") |
| 503 | model_to_offload.model.to(offload_device) |
| 504 | mm.soft_empty_cache() |
| 505 | except Exception: |
| 506 | pass |
| 507 | |
| 508 | encoder = t5["model"] |
| 509 | dtype = t5["dtype"] |
| 510 | |
| 511 | if device == "gpu": |
| 512 | device_to = mm.get_torch_device() |
| 513 | else: |
| 514 | device_to = torch.device("cpu") |
| 515 | |
| 516 | if encoder.quantization == "fp8_e4m3fn": |
| 517 | cast_dtype = torch.float8_e4m3fn |
| 518 | else: |
| 519 | cast_dtype = encoder.dtype |
| 520 | params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} |
| 521 | for name, param in encoder.model.named_parameters(): |
| 522 | dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype |
| 523 | value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] |
| 524 | set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) |
| 525 | if hasattr(encoder, 'state_dict'): |
| 526 | del encoder.state_dict |
| 527 | mm.soft_empty_cache() |
| 528 | gc.collect() |
| 529 | with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): |
| 530 | encoded = encoder([prompt], device_to) |
| 531 | |
| 532 | if force_offload: |
| 533 | encoder.model.to(offload_device) |
| 534 | mm.soft_empty_cache() |
nothing calls this directly
no test coverage detected