(args)
| 31 | |
| 32 | |
| 33 | def merge(args): |
| 34 | if args.precision == "fp16": |
| 35 | dtype = torch.float16 |
| 36 | elif args.precision == "bf16": |
| 37 | dtype = torch.bfloat16 |
| 38 | else: |
| 39 | dtype = torch.float |
| 40 | |
| 41 | if args.saving_precision == "fp16": |
| 42 | save_dtype = torch.float16 |
| 43 | elif args.saving_precision == "bf16": |
| 44 | save_dtype = torch.bfloat16 |
| 45 | else: |
| 46 | save_dtype = torch.float |
| 47 | |
| 48 | # check if all models are safetensors |
| 49 | for model in args.models: |
| 50 | if not model.endswith("safetensors"): |
| 51 | logger.info(f"Model {model} is not a safetensors model") |
| 52 | exit() |
| 53 | if not os.path.isfile(model): |
| 54 | logger.info(f"Model {model} does not exist") |
| 55 | exit() |
| 56 | |
| 57 | assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" |
| 58 | |
| 59 | # load and merge |
| 60 | ratio = 1.0 / len(args.models) # default |
| 61 | supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later |
| 62 | |
| 63 | merged_sd = None |
| 64 | first_model_keys = set() # check missing keys in other models |
| 65 | for i, model in enumerate(args.models): |
| 66 | if args.ratios is not None: |
| 67 | ratio = args.ratios[i] |
| 68 | |
| 69 | if merged_sd is None: |
| 70 | # load first model |
| 71 | logger.info(f"Loading model {model}, ratio = {ratio}...") |
| 72 | merged_sd = {} |
| 73 | with safe_open(model, framework="pt", device=args.device) as f: |
| 74 | for key in tqdm(f.keys()): |
| 75 | value = f.get_tensor(key) |
| 76 | _, key = replace_text_encoder_key(key) |
| 77 | |
| 78 | first_model_keys.add(key) |
| 79 | |
| 80 | if not is_unet_key(key) and args.unet_only: |
| 81 | supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder |
| 82 | continue |
| 83 | |
| 84 | value = ratio * value.to(dtype) # first model's value * ratio |
| 85 | merged_sd[key] = value |
| 86 | |
| 87 | logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) |
| 88 | continue |
| 89 | |
| 90 | # load other models |
no test coverage detected