(
is_sdxl: bool,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 1,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
clip_skip: int = 1,
token_replacer=None,
device=None,
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
**kwargs,
)
| 1192 | |
| 1193 | |
| 1194 | def get_weighted_text_embeddings( |
| 1195 | is_sdxl: bool, |
| 1196 | tokenizer: CLIPTokenizer, |
| 1197 | text_encoder: CLIPTextModel, |
| 1198 | prompt: Union[str, List[str]], |
| 1199 | uncond_prompt: Optional[Union[str, List[str]]] = None, |
| 1200 | max_embeddings_multiples: Optional[int] = 1, |
| 1201 | no_boseos_middle: Optional[bool] = False, |
| 1202 | skip_parsing: Optional[bool] = False, |
| 1203 | skip_weighting: Optional[bool] = False, |
| 1204 | clip_skip: int = 1, |
| 1205 | token_replacer=None, |
| 1206 | device=None, |
| 1207 | emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" |
| 1208 | **kwargs, |
| 1209 | ): |
| 1210 | max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 |
| 1211 | if isinstance(prompt, str): |
| 1212 | prompt = [prompt] |
| 1213 | |
| 1214 | # split the prompts with "AND". each prompt must have the same number of splits |
| 1215 | new_prompts = [] |
| 1216 | for p in prompt: |
| 1217 | new_prompts.extend(p.split(" AND ")) |
| 1218 | prompt = new_prompts |
| 1219 | |
| 1220 | if not skip_parsing: |
| 1221 | prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) |
| 1222 | if uncond_prompt is not None: |
| 1223 | if isinstance(uncond_prompt, str): |
| 1224 | uncond_prompt = [uncond_prompt] |
| 1225 | uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) |
| 1226 | else: |
| 1227 | prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] |
| 1228 | prompt_weights = [[1.0] * len(token) for token in prompt_tokens] |
| 1229 | if uncond_prompt is not None: |
| 1230 | if isinstance(uncond_prompt, str): |
| 1231 | uncond_prompt = [uncond_prompt] |
| 1232 | uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] |
| 1233 | uncond_weights = [[1.0] * len(token) for token in uncond_tokens] |
| 1234 | |
| 1235 | # round up the longest length of tokens to a multiple of (model_max_length - 2) |
| 1236 | max_length = max([len(token) for token in prompt_tokens]) |
| 1237 | if uncond_prompt is not None: |
| 1238 | max_length = max(max_length, max([len(token) for token in uncond_tokens])) |
| 1239 | |
| 1240 | max_embeddings_multiples = min( |
| 1241 | max_embeddings_multiples, |
| 1242 | (max_length - 1) // (tokenizer.model_max_length - 2) + 1, |
| 1243 | ) |
| 1244 | max_embeddings_multiples = max(1, max_embeddings_multiples) |
| 1245 | max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 |
| 1246 | |
| 1247 | # pad the length of tokens and weights |
| 1248 | bos = tokenizer.bos_token_id |
| 1249 | eos = tokenizer.eos_token_id |
| 1250 | pad = tokenizer.pad_token_id |
| 1251 | prompt_tokens, prompt_weights = pad_tokens_and_weights( |
no test coverage detected