| 23 | |
| 24 | |
def patched_encode_token_weights(self, token_weight_pairs):
    """Encode weighted token sections and apply the per-token weights.

    Each entry of ``token_weight_pairs`` is a sequence of pairs whose first
    element is passed to ``self.encode`` as the token and whose second element
    is treated as that token's weight. A weight equal to ``1.0`` leaves the
    token's encoding untouched.

    When any weight differs from ``1.0`` (or when there are no sections at
    all), one extra entry produced by ``gen_empty_tokens`` is appended to the
    batch; its encoding is the baseline the weighted tokens are interpolated
    against.

    Returns:
        A ``(encoded, first_pooled)`` tuple. ``encoded`` is the per-section
        encodings concatenated along ``dim=-2`` (or the last encoded entry
        when there are no sections), moved to the intermediate device.
        ``first_pooled`` is the first row of the pooled output moved to the
        intermediate device, or ``None`` when ``self.encode`` returned no
        pooled output.
    """

    def _to_intermediate(tensor):
        # The device is looked up at each use, exactly where a result is moved.
        return tensor.to(ldm_patched.modules.model_management.intermediate_device())

    token_batches = []
    longest = 0
    has_weights = False
    for section in token_weight_pairs:
        ids = [pair[0] for pair in section]
        longest = max(len(ids), longest)
        # Once a non-unit weight has been seen there is no need to scan again.
        if not has_weights:
            has_weights = not all(pair[1] == 1.0 for pair in section)
        token_batches.append(ids)

    sections = len(token_batches)
    if has_weights or sections == 0:
        # Extra "empty" entry; its encoding ends up at out[-1].
        token_batches.append(
            ldm_patched.modules.sd1_clip.gen_empty_tokens(self.special_tokens, longest)
        )

    out, pooled = self.encode(token_batches)
    first_pooled = None if pooled is None else _to_intermediate(pooled[0:1])

    output = []
    for k in range(sections):
        z = out[k:k + 1]
        if has_weights:
            section_pairs = token_weight_pairs[k]
            baseline = out[-1]
            mean_before = z.mean()
            for i in range(len(z)):
                row = z[i]  # indexing result; the write below goes through it, as z[i][j] = ... does
                for j in range(len(row)):
                    weight = section_pairs[j][1]
                    if weight == 1.0:
                        continue
                    # Scale the token's offset from the baseline encoding by its weight.
                    row[j] = (row[j] - baseline[j]) * weight + baseline[j]
            mean_after = z.mean()
            # Rescale so the section keeps the mean it had before weighting.
            z = z * (mean_before / mean_after)
        output.append(z)

    if not output:
        return _to_intermediate(out[-1:]), first_pooled
    return _to_intermediate(torch.cat(output, dim=-2)), first_pooled
| 63 | |
| 64 | |
| 65 | def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None, |