(args)
| 49 | |
| 50 | |
def load_original_checkpoint(args):
    """Load the original (pre-conversion) checkpoint as a state dict.

    The checkpoint location is resolved from ``args``:

    * If ``args.original_state_dict_repo_id`` is set, the file named
      ``args.filename`` is fetched from that Hub repo via ``hf_hub_download``.
      This takes precedence over a local path.
    * Otherwise, if ``args.checkpoint_path`` is set, that local path is used.

    Args:
        args: Parsed CLI arguments. The function reads the attributes
            ``original_state_dict_repo_id``, ``filename`` and
            ``checkpoint_path``.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for the resolved file
        (the checkpoint's state dict).

    Raises:
        ValueError: If neither a repo id nor a local checkpoint path was given.
    """
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        # Fixed message: previously had a stray leading space and lowercase start.
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    # The checkpoint is expected to be in safetensors format.
    return safetensors.torch.load_file(ckpt_path)
| 61 | |
| 62 | |
| 63 | # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; |
no outgoing calls
no test coverage detected
searching dependent graphs…