(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None)
| 140 | |
| 141 | |
def _squeeze_kernel_dims(weight):
    """Drop trailing size-1 (kernel) dims from a projection weight, keeping it at least 2-D."""
    # Only strip dims beyond (out, in); a blind squeeze(-1) would also collapse a
    # genuine size-1 channel dimension.
    while weight.dim() > 2 and weight.shape[-1] == 1:
        weight = weight.squeeze(-1)
    return weight


def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
    """Copy one attention block's parameters from old-style keys to new-style keys.

    Reads from ``checkpoint`` under ``old_prefix``:
      * ``norm.weight`` / ``norm.bias``
      * ``qkv.weight`` / ``qkv.bias``: a fused projection, split into three equal
        parts along dim 0 and assigned to query, key, value in that order
      * ``proj_out.weight`` / ``proj_out.bias``

    Writes to ``new_checkpoint`` under ``new_prefix``:
      ``group_norm.{weight,bias}``, ``to_q.{weight,bias}``, ``to_k.{weight,bias}``,
      ``to_v.{weight,bias}`` and ``to_out.0.{weight,bias}``.

    Projection weights have trailing size-1 dims removed, so they become 2-D
    ``(out, in)`` matrices. These dims are presumably 1x1 conv kernels; TODO confirm
    against the source model. Biases are copied without reshaping.

    Args:
        checkpoint: Mapping from old parameter names to tensors. The tensors must
            support ``chunk``, ``squeeze``, ``dim`` and ``shape``
            (e.g. ``torch.Tensor``).
        new_checkpoint: Mapping that is updated in place.
        old_prefix: Key prefix of the attention block in ``checkpoint``.
        new_prefix: Key prefix to write under in ``new_checkpoint``.
        attention_dim: Unused; accepted only for signature compatibility.

    Returns:
        ``new_checkpoint``, the same object that was passed in, after mutation.

    Raises:
        KeyError: If an expected key is missing from ``checkpoint``.
        ValueError: If the fused ``qkv`` tensors cannot be split into three equal parts.
    """
    qkv_weight = checkpoint[f"{old_prefix}.qkv.weight"]
    qkv_bias = checkpoint[f"{old_prefix}.qkv.bias"]
    # chunk(3) silently yields uneven pieces when the size is not a multiple of 3.
    if qkv_weight.shape[0] % 3 or qkv_bias.shape[0] % 3:
        raise ValueError(
            f"expected '{old_prefix}.qkv.weight' and '{old_prefix}.qkv.bias' to have a "
            f"leading dimension divisible by 3, got {tuple(qkv_weight.shape)} and "
            f"{tuple(qkv_bias.shape)}"
        )
    weight_q, weight_k, weight_v = qkv_weight.chunk(3, dim=0)
    bias_q, bias_k, bias_v = qkv_bias.chunk(3, dim=0)

    new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
    new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]

    new_checkpoint[f"{new_prefix}.to_q.weight"] = _squeeze_kernel_dims(weight_q)
    new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q
    new_checkpoint[f"{new_prefix}.to_k.weight"] = _squeeze_kernel_dims(weight_k)
    new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k
    new_checkpoint[f"{new_prefix}.to_v.weight"] = _squeeze_kernel_dims(weight_v)
    new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v

    new_checkpoint[f"{new_prefix}.to_out.0.weight"] = _squeeze_kernel_dims(
        checkpoint[f"{old_prefix}.proj_out.weight"]
    )
    new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"]

    return new_checkpoint
| 162 | |
| 163 | |
| 164 | def con_pt_to_diffuser(checkpoint_path: str, unet_config): |
no outgoing calls
no test coverage detected
searching dependent graphs…