(self, key, step=0)
| 270 | _clone(self.negative_prompt_attention_masks) |
| 271 | |
def __call__(self, key, step=0):
    """Return the embeddings stored in attribute ``key`` for ``step``, joined across the batch.

    ``getattr(self, key)`` is indexed first by batch element (``0 .. self.batchsize - 1``)
    and then by step. If ``step`` is out of range for an element, that element's
    entry 0 is used instead (the "not scheduled" default).

    Args:
        key: name of an attribute on ``self`` that holds the per-element lists.
        step: index into each element's list.

    Returns:
        ``None`` if the attribute is empty, if any visited element has no entries,
        or if an exception occurs (the exception is logged).
        For the HiDream layout (entries are ``[t5_embeds, llama_embeds]`` pairs):
        a two-item list ``[cat(t5, dim=0), cat(llama, dim=1)]``.
        Otherwise: one tensor concatenated along dim 0. If the entries differ in
        dim 1, they are first passed through ``pad_to_same_length``.
    """
    batch = getattr(self, key)
    res = []
    try:
        if len(batch) == 0 or len(batch[0]) == 0:
            return None  # flux has no negative prompts
        first = batch[0][0]
        # NOTE(review): 32 presumably is the number of llama hidden states stacked in
        # dim 0 of llama_embeds; it is used here only to recognise the HiDream layout.
        if isinstance(first, list) and len(first) == 2 and isinstance(first[1], torch.Tensor) and first[1].shape[0] == 32:
            # hidream uses a list of t5 + llama prompt embeds: [t5_embeds, llama_embeds]
            # t5_embeds shape: [batch_size, seq_len, dim]
            # llama_embeds shape: [number_of_hidden_states, batch_size, seq_len, dim]
            res2 = []
            for i in range(self.batchsize):
                if len(batch[i]) == 0:  # if asking for a null key, ie pooled on SD1.5
                    return None
                # Resolve both halves before appending anything. If a lookup fails
                # part-way, res and res2 still cannot drift to different lengths
                # (that would misalign the batch after concatenation).
                try:
                    t5_embeds, llama_embeds = batch[i][step][0], batch[i][step][1]
                except IndexError:
                    # if not scheduled, return default
                    t5_embeds, llama_embeds = batch[i][0][0], batch[i][0][1]
                res.append(t5_embeds)
                res2.append(llama_embeds)
            # llama embeds carry the batch on dim 1, t5 embeds on dim 0
            return [torch.cat(res, dim=0), torch.cat(res2, dim=1)]
        for i in range(self.batchsize):
            if len(batch[i]) == 0:  # if asking for a null key, ie pooled on SD1.5
                return None
            try:
                res.append(batch[i][step])
            except IndexError:
                res.append(batch[i][0])  # if not scheduled, return default
        # torch.cat needs matching non-batch dims, so pad when dim 1 differs
        if any(res[0].shape[1] != r.shape[1] for r in res):
            res = pad_to_same_length(self.pipe, res, diffusers_zeros_prompt_pad=self.diffusers_zeros_prompt_pad)
        return torch.cat(res)
    except Exception as e:
        # best-effort: callers receive None instead of an exception
        log.error(f"Prompt encode: {e}")
        return None
| 309 | |
| 310 | |
| 311 | def compel_hijack(self, token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: |
nothing calls this directly
no test coverage detected