(self, pipe, positive_prompt, negative_prompt, batchidx)
| 207 | self.negative_prompt_attention_masks[batchidx].append(self.negative_prompt_attention_masks[batchidx][idx]) |
| 208 | |
def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
    """Encode one positive/negative prompt pair and record the outputs for a batch slot.

    ``None`` prompts are replaced by the empty string before encoding. The attention
    mode is copied from ``self.prompt_attention_value`` onto ``self.attention`` and
    mirrored into the module-level ``last_attention``. Mode ``"xhinker"`` is encoded
    with ``get_xhinker_text_embeddings``. Every other mode uses
    ``get_weighted_text_embeddings`` together with this object's ``prompt_mean_norm``,
    ``diffusers_zeros_prompt_pad`` and ``te_pooled_embeds`` options.

    Each of the six returned values that is not ``None`` is written into the matching
    per-batch container at index ``batchidx``:

    - When ``self.scheduled_prompt`` is truthy and the slot already holds entries, the
      value is appended, so every schedule slice is kept.
    - Otherwise the slot is replaced by a one-element list.

    Args:
        pipe: passed through unchanged to the embedding helpers and to ``get_tokens``.
        positive_prompt: positive prompt text; ``None`` is treated as ``''``.
        negative_prompt: negative prompt text; ``None`` is treated as ``''``.
        batchidx: index of the slot to fill in each per-batch container.
    """
    global last_attention  # pylint: disable=global-statement
    positive_prompt = '' if positive_prompt is None else positive_prompt
    negative_prompt = '' if negative_prompt is None else negative_prompt

    self.attention = self.prompt_attention_value
    # read back through self.attention so any attribute/property handling is respected
    last_attention = self.attention

    if self.attention != "xhinker":
        encoded = get_weighted_text_embeddings(
            pipe,
            positive_prompt,
            negative_prompt,
            self.clip_skip,
            prompt_mean_norm=self.prompt_mean_norm,
            diffusers_zeros_prompt_pad=self.diffusers_zeros_prompt_pad,
            te_pooled_embeds=self.te_pooled_embeds,
        )
    else:
        encoded = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
    # unpack immediately so a wrong-arity result fails before anything is stored
    (
        pos_embed,
        pos_pooled,
        pos_mask,
        neg_embed,
        neg_pooled,
        neg_mask,
    ) = encoded

    # (container attribute, value), in the order the outputs are recorded
    destinations = (
        ('prompt_embeds', pos_embed),
        ('negative_prompt_embeds', neg_embed),
        ('positive_pooleds', pos_pooled),
        ('negative_pooleds', neg_pooled),
        ('prompt_attention_masks', pos_mask),
        ('negative_prompt_attention_masks', neg_mask),
    )
    for attr_name, output in destinations:
        slots = getattr(self, attr_name)
        if output is None:
            continue
        # unscheduled prompts (or an empty slot) start fresh; scheduled prompts keep all slices
        if not self.scheduled_prompt or len(slots[batchidx]) == 0:
            slots[batchidx] = [output]
        else:
            slots[batchidx].append(output)

    if debug_enabled:
        for label, text in (('positive', positive_prompt), ('negative', negative_prompt)):
            get_tokens(pipe, label, text)
| 254 | |
| 255 | def clone_embeds(self, batchidx, idx): |
| 256 | def _clone(target): |
no test coverage detected