| 1960 | return (vae,) |
| 1961 | |
| 1962 | class LoadWanVideoT5TextEncoder: |
@classmethod
def INPUT_TYPES(s):
    """Describe the inputs this ComfyUI node exposes.

    Returns a mapping with two sections:
      * "required": the checkpoint file (chosen from the list that
        folder_paths reports for "text_encoders") and the precision,
        which is one of "fp32"/"bf16" and defaults to "bf16".
      * "optional": the load device, which defaults to "offload_device",
        and an optional quantization mode, which defaults to "disabled".
    """
    encoder_choices = folder_paths.get_filename_list("text_encoders")
    required_inputs = {
        "model_name": (encoder_choices, {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}),
        "precision": (["fp32", "bf16"], {"default": "bf16"}),
    }
    optional_inputs = {
        "load_device": (["main_device", "offload_device"], {"default": "offload_device"}),
        "quantization": (["disabled", "fp8_e4m3fn"], {"default": "disabled", "tooltip": "optional quantization method"}),
    }
    return {"required": required_inputs, "optional": optional_inputs}
| 1977 | |
# ComfyUI node metadata: output socket type/name, the method ComfyUI
# invokes ("loadmodel"), the menu category, and the help text.
RETURN_TYPES = ("WANTEXTENCODER",)
RETURN_NAMES = ("wan_t5_model",)
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
# The checkpoint list comes from folder_paths' "text_encoders" folder, and
# loadmodel resolves the file there too. The help text must name that folder,
# not 'ComfyUI/models/LLM'.
DESCRIPTION = "Loads Wan text_encoder model from 'ComfyUI/models/text_encoders'"
| 1983 | |
| 1984 | def loadmodel(self, model_name, precision, load_device="offload_device", quantization="disabled"): |
| 1985 | text_encoder_load_device = device if load_device == "main_device" else offload_device |
| 1986 | |
| 1987 | tokenizer_path = os.path.join(script_directory, "configs", "T5_tokenizer") |
| 1988 | |
| 1989 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] |
| 1990 | |
| 1991 | model_path = folder_paths.get_full_path_or_raise("text_encoders", model_name) |
| 1992 | sd = load_torch_file(model_path, safe_load=True) |
| 1993 | |
| 1994 | if quantization == "disabled": |
| 1995 | for k, v in sd.items(): |
| 1996 | if isinstance(v, torch.Tensor): |
| 1997 | if v.dtype == torch.float8_e4m3fn: |
| 1998 | quantization = "fp8_e4m3fn" |
| 1999 | break |
| 2000 | |
| 2001 | if "token_embedding.weight" not in sd and "shared.weight" not in sd: |
| 2002 | raise ValueError("Invalid T5 text encoder model, this node expects the 'umt5-xxl' model") |
| 2003 | if "scaled_fp8" in sd: |
| 2004 | raise ValueError("Invalid T5 text encoder model, fp8 scaled is not supported by this node") |
| 2005 | |
| 2006 | # Convert state dict keys from T5 format to the expected format |
| 2007 | if "shared.weight" in sd: |
| 2008 | log.info("Converting T5 text encoder model to the expected format...") |
| 2009 | converted_sd = {} |
| 2010 | |
| 2011 | for key, value in sd.items(): |
| 2012 | # Handle encoder block patterns |
| 2013 | if key.startswith('encoder.block.'): |
| 2014 | parts = key.split('.') |
| 2015 | block_num = parts[2] |
| 2016 | |
| 2017 | # Self-attention components |
| 2018 | if 'layer.0.SelfAttention' in key: |
| 2019 | if key.endswith('.k.weight'): |