MCPcopy
hub / github.com/kohya-ss/sd-scripts / load_control_net

Function load_control_net

tools/original_control_net.py:52–105  ·  view source on GitHub ↗
(v2, unet, model)

Source from the content-addressed store, hash-verified

50
51
52def load_control_net(v2, unet, model):
53 device = unet.device
54
55 # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
56 # state dictを読み込む
57 logger.info(f"ControlNet: loading control SD model : {model}")
58
59 if model_util.is_safetensors(model):
60 ctrl_sd_sd = load_file(model)
61 else:
62 ctrl_sd_sd = torch.load(model, map_location="cpu")
63 ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
64
65 # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
66 is_difference = "difference" in ctrl_sd_sd
67 logger.info(f"ControlNet: loading difference: {is_difference}")
68
69 # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
70 # またTransfer Controlの元weightとなる
71 ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
72
73 # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
74 for key in list(ctrl_unet_sd_sd.keys()):
75 ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
76
77 zero_conv_sd = {}
78 for key in list(ctrl_sd_sd.keys()):
79 if key.startswith("control_"):
80 unet_key = "model.diffusion_" + key[len("control_") :]
81 if unet_key not in ctrl_unet_sd_sd: # zero conv
82 zero_conv_sd[key] = ctrl_sd_sd[key]
83 continue
84 if is_difference: # Transfer Control
85 ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
86 else:
87 ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
88
89 unet_config = model_util.create_unet_diffusers_config(v2)
90 ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
91
92 # ControlNetのU-Netを作成する
93 ctrl_unet = UNet2DConditionModel(**unet_config)
94 info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
95 logger.info(f"ControlNet: loading Control U-Net: {info}")
96
97 # U-Net以外のControlNetを作成する
98 # TODO support middle only
99 ctrl_net = ControlNet()
100 info = ctrl_net.load_state_dict(zero_conv_sd)
101 logger.info("ControlNet: loading ControlNet: {info}")
102
103 ctrl_unet.to(unet.device, dtype=unet.dtype)
104 ctrl_net.to(unet.device, dtype=unet.dtype)
105 return ctrl_unet, ctrl_net
106
107
108def load_preprocess(prep_type: str):

Callers

nothing calls this directly

Calls 6

ControlNetClass · 0.85
toMethod · 0.80
state_dictMethod · 0.45
keysMethod · 0.45
load_state_dictMethod · 0.45

Tested by

no test coverage detected