MCPcopy
hub / github.com/kohya-ss/sd-scripts / merge

Function merge

tools/merge_models.py:33–148  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

31
32
33def merge(args):
34 if args.precision == "fp16":
35 dtype = torch.float16
36 elif args.precision == "bf16":
37 dtype = torch.bfloat16
38 else:
39 dtype = torch.float
40
41 if args.saving_precision == "fp16":
42 save_dtype = torch.float16
43 elif args.saving_precision == "bf16":
44 save_dtype = torch.bfloat16
45 else:
46 save_dtype = torch.float
47
48 # check if all models are safetensors
49 for model in args.models:
50 if not model.endswith("safetensors"):
51 logger.info(f"Model {model} is not a safetensors model")
52 exit()
53 if not os.path.isfile(model):
54 logger.info(f"Model {model} does not exist")
55 exit()
56
57 assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
58
59 # load and merge
60 ratio = 1.0 / len(args.models) # default
61 supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later
62
63 merged_sd = None
64 first_model_keys = set() # check missing keys in other models
65 for i, model in enumerate(args.models):
66 if args.ratios is not None:
67 ratio = args.ratios[i]
68
69 if merged_sd is None:
70 # load first model
71 logger.info(f"Loading model {model}, ratio = {ratio}...")
72 merged_sd = {}
73 with safe_open(model, framework="pt", device=args.device) as f:
74 for key in tqdm(f.keys()):
75 value = f.get_tensor(key)
76 _, key = replace_text_encoder_key(key)
77
78 first_model_keys.add(key)
79
80 if not is_unet_key(key) and args.unet_only:
81 supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder
82 continue
83
84 value = ratio * value.to(dtype) # first model's value * ratio
85 merged_sd[key] = value
86
87 logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
88 continue
89
90 # load other models

Callers 1

merge_models.py · File · 0.70

Calls 6

replace_text_encoder_key · Function · 0.85
is_unet_key · Function · 0.85
add · Method · 0.80
to · Method · 0.80
keys · Method · 0.45
get_tensor · Method · 0.45

Tested by

no test coverage detected