MCPcopy
hub / github.com/kohya-ss/sd-scripts / get_weighted_text_embeddings

Function get_weighted_text_embeddings

gen_img.py:1194–1335  ·  view source on GitHub ↗
(
    is_sdxl: bool,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 1,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip: int = 1,
    token_replacer=None,
    device=None,
    emb_normalize_mode: Optional[str] = "original",  # "original", "abs", "none"
    **kwargs,
)

Source from the content-addressed store, hash-verified

1192
1193
1194 def get_weighted_text_embeddings(
1195 is_sdxl: bool,
1196 tokenizer: CLIPTokenizer,
1197 text_encoder: CLIPTextModel,
1198 prompt: Union[str, List[str]],
1199 uncond_prompt: Optional[Union[str, List[str]]] = None,
1200 max_embeddings_multiples: Optional[int] = 1,
1201 no_boseos_middle: Optional[bool] = False,
1202 skip_parsing: Optional[bool] = False,
1203 skip_weighting: Optional[bool] = False,
1204 clip_skip: int = 1,
1205 token_replacer=None,
1206 device=None,
1207 emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
1208 **kwargs,
1209):
1210 max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
1211 if isinstance(prompt, str):
1212 prompt = [prompt]
1213
1214 # split the prompts with "AND". each prompt must have the same number of splits
1215 new_prompts = []
1216 for p in prompt:
1217 new_prompts.extend(p.split(" AND "))
1218 prompt = new_prompts
1219
1220 if not skip_parsing:
1221 prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2)
1222 if uncond_prompt is not None:
1223 if isinstance(uncond_prompt, str):
1224 uncond_prompt = [uncond_prompt]
1225 uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2)
1226 else:
1227 prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
1228 prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
1229 if uncond_prompt is not None:
1230 if isinstance(uncond_prompt, str):
1231 uncond_prompt = [uncond_prompt]
1232 uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids]
1233 uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
1234
1235 # round up the longest length of tokens to a multiple of (model_max_length - 2)
1236 max_length = max([len(token) for token in prompt_tokens])
1237 if uncond_prompt is not None:
1238 max_length = max(max_length, max([len(token) for token in uncond_tokens]))
1239
1240 max_embeddings_multiples = min(
1241 max_embeddings_multiples,
1242 (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
1243 )
1244 max_embeddings_multiples = max(1, max_embeddings_multiples)
1245 max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
1246
1247 # pad the length of tokens and weights
1248 bos = tokenizer.bos_token_id
1249 eos = tokenizer.eos_token_id
1250 pad = tokenizer.pad_token_id
1251 prompt_tokens, prompt_weights = pad_tokens_and_weights(

Callers 1

__call__ · Method · 0.70

Calls 4

to · Method · 0.80
get_prompts_with_weights · Function · 0.70
pad_tokens_and_weights · Function · 0.70

Tested by

no test coverage detected