(pipe, embeds, empty_embedding_providers=None, diffusers_zeros_prompt_pad=None)
| 543 | |
| 544 | |
def pad_to_same_length(pipe, embeds, empty_embedding_providers=None, diffusers_zeros_prompt_pad=None):
    """Pad prompt embeddings along dim 1 so every tensor in ``embeds`` ends up the same length.

    Shorter tensors are extended by appending copies of an "empty prompt" embedding, trimmed so
    each one ends exactly at the length of the longest tensor. The list is modified in place and
    also returned.

    Args:
        pipe: Pipeline object. Nothing is padded unless it has an ``encode_prompt`` attribute or
            its class name contains ``StableCascade``.
        embeds: List of 3-D tensors; dim 1 is the token axis that gets padded. Presumably
            (batch, tokens, features) prompt embeddings -- TODO confirm against callers.
        empty_embedding_providers: Only used for StableCascade pipelines; element 0 must provide
            ``get_embeddings_for_weighted_prompt_fragments``.
        diffusers_zeros_prompt_pad: When not None, overrides ``shared.opts.diffusers_zeros_prompt_pad``.
            If truthy (or for StableDiffusion3 pipelines), the padding is all-zero 77-token chunks
            instead of an encoded empty prompt.

    Returns:
        The same ``embeds`` list, with shorter entries replaced by padded tensors.
    """
    pipe_name = pipe.__class__.__name__
    if not hasattr(pipe, 'encode_prompt') and 'StableCascade' not in pipe_name:
        return embeds
    if not embeds:
        return embeds
    max_token_count = max(embed.shape[1] for embed in embeds)
    if all(embed.shape[1] == max_token_count for embed in embeds):
        return embeds  # already aligned: skip encoding an empty prompt altogether

    zeros_pad = diffusers_zeros_prompt_pad if diffusers_zeros_prompt_pad is not None else shared.opts.diffusers_zeros_prompt_pad
    if zeros_pad or 'StableDiffusion3' in pipe_name:
        # Built directly on the embeddings' device/dtype; it is moved per tensor below anyway.
        ref = embeds[0]
        empty_embed = torch.zeros((1, 77, ref.shape[2]), device=ref.device, dtype=ref.dtype)
    else:
        device = devices.device
        try:
            if 'StableCascade' in pipe_name:
                # NOTE(review): if empty_embedding_providers is None, indexing it raises TypeError
                # and execution falls through to the encode_prompt fallback below -- confirm intended.
                encoded = [empty_embedding_providers[0].get_embeddings_for_weighted_prompt_fragments(text_batch=[[""]], fragment_weights_batch=[[1]], should_return_tokens=False, device=device)]
            else:
                encoded = pipe.encode_prompt("")
        except TypeError:  # SD1.5: encode_prompt("") not accepted, retry with positional args
            encoded = pipe.encode_prompt("", device, 1, False)
        empty_embed = encoded[0]

    chunk_len = empty_embed.shape[1]
    for i, embed in enumerate(embeds):
        gap = max_token_count - embed.shape[1]
        if gap <= 0:
            continue
        # The gap is measured per tensor (not longest-minus-shortest), and the chunk count is
        # rounded up and then trimmed, so every tensor lands exactly on max_token_count.
        chunks = -(-gap // chunk_len)
        pad = empty_embed.to(device=embed.device, dtype=embed.dtype).repeat(embed.shape[0], chunks, 1)[:, :gap]
        embeds[i] = torch.cat([embed, pad], dim=1)
    return embeds
| 569 | |
| 570 | |
| 571 | def split_prompts(pipe, prompt, SD3 = False): |
no test coverage detected