(ckpt_path)
| 22 | |
| 23 | |
def load_original_checkpoint(ckpt_path):
    """Load an original checkpoint from a safetensors file and drop the UNet-style prefix.

    Every key containing ``"model.diffusion_model."`` is renamed by removing that
    substring. Note that ``str.replace`` removes *every* occurrence, not only a
    leading one. Renamed entries are popped and re-inserted, so they end up after
    the untouched keys in the dict's iteration order.

    Args:
        ckpt_path: Path to the ``.safetensors`` checkpoint file.

    Returns:
        The loaded state dict (the same dict object returned by
        ``safetensors.torch.load_file``), with keys renamed in place.
    """
    prefix = "model.diffusion_model."
    state_dict = safetensors.torch.load_file(ckpt_path)

    # Collect the names to rename before touching the dict: the loop below
    # pops and re-inserts entries, which is not allowed while iterating it.
    prefixed_names = [name for name in state_dict if prefix in name]
    for name in prefixed_names:
        state_dict[name.replace(prefix, "")] = state_dict.pop(name)

    return state_dict
| 32 | |
| 33 | |
| 34 | # In the original SD3 implementation of AdaLayerNormContinuous, it split the linear projection output into shift, scale;