| 113 | debug(f"Prompt encode: time={(time.time() - t0):.3f}") |
| 114 | |
def checkcache(self, p) -> bool:
    """Restore this prompt's embeddings from the module-level prompt cache, or store them.

    The cache key is built from the positive/negative prompts, the effective batch
    size (1 when every prompt in the batch is identical), ``clip_skip``, ``steps`` and
    the ``items`` of each entry in ``p.network_data`` (extra-network data, which matters
    for text-encoder LoRA). Any change to those inputs therefore produces a miss.

    On a hit, the cached embedding lists are assigned onto ``self`` and the entry is
    marked most-recently-used. When all prompts are identical and fewer embeddings
    than ``self.batchsize`` are cached, the first entry is repeated to fill the batch.
    On a miss, the current embeddings are stored if at least one of them is non-empty.
    Then the least-recently-used entries are evicted until the cache fits the
    configured size.

    Args:
        p: processing object; only ``p.network_data`` is read here.

    Returns:
        True if embeddings were restored from, or added to, the cache. False if
        caching is disabled, the cache was invalidated, or there is nothing to store.
    """
    max_size = int(shared.opts.sd_textencoder_cache_size)
    # Any non-positive size disables caching. A plain `== 0` check let negative values
    # through. The eviction loop below then popped from an empty cache (KeyError).
    if max_size <= 0:
        return False
    if self.scheduled_prompt:
        # scheduled prompts bypass the cache entirely and invalidate existing entries
        debug("Prompt cache: scheduled prompt")
        cache.clear()
        return False
    if self.attention != self.prompt_attention_value:
        # the attention parser setting changed, so all cached entries are discarded
        debug(f"Prompt cache: parser={self.prompt_attention_value} changed")
        cache.clear()
        return False

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def expand(values):
        # Repeat the first entry to fill the batch. Empty lists are left untouched:
        # an entry is stored as long as *one* embedding list is non-empty, so the
        # others may legitimately be empty and must not be indexed with [0].
        return [values[0]] * self.batchsize if values else values

    # unpack EN data in case of TE LoRA
    en_data = [idx.items for item in p.network_data.values() for idx in item]
    effective_batch = 1 if self.allsame else self.batchsize
    key = str([self.prompts, self.negative_prompts, effective_batch, self.clip_skip, self.steps, en_data])
    item = cache.get(key)

    if not item:
        embeds = [
            self.prompt_embeds,
            self.negative_prompt_embeds,
            self.positive_pooleds,
            self.negative_pooleds,
            self.prompt_attention_masks,
            self.negative_prompt_attention_masks,
        ]
        if not any(flatten(emb) for emb in embeds):
            return False  # nothing encoded yet, nothing worth caching
        cache[key] = {
            'prompt_embeds': self.prompt_embeds,
            'negative_prompt_embeds': self.negative_prompt_embeds,
            'positive_pooleds': self.positive_pooleds,
            'negative_pooleds': self.negative_pooleds,
            'prompt_attention_masks': self.prompt_attention_masks,
            'negative_prompt_attention_masks': self.negative_prompt_attention_masks,
        }
        debug(f"Prompt cache: add={key}")
        # evict from the front of the ordered cache (least recently used first)
        while len(cache) > max_size:
            cache.popitem(last=False)
        return True

    # Cache hit: restore the stored lists onto self and mark the entry as most recently used.
    # NOTE(review): self now shares the list objects held by the cache; in-place
    # mutation of these lists elsewhere would also alter the cached entry - confirm none occurs.
    self.__dict__.update(item)
    cache.move_to_end(key)
    if self.allsame and len(self.prompt_embeds) < self.batchsize:
        self.prompt_embeds = expand(self.prompt_embeds)
        self.positive_pooleds = expand(self.positive_pooleds)
        self.negative_prompt_embeds = expand(self.negative_prompt_embeds)
        self.negative_pooleds = expand(self.negative_pooleds)
        self.prompt_attention_masks = expand(self.prompt_attention_masks)
        self.negative_prompt_attention_masks = expand(self.negative_prompt_attention_masks)
    debug(f"Prompt cache: get={key}")
    return True
| 168 | |
| 169 | def compare_prompts(self): |
| 170 | same = (self.prompts == [self.prompts[0]] * len(self.prompts) and self.negative_prompts == [self.negative_prompts[0]] * len(self.negative_prompts)) |