(v2, unet, model)
| 50 | |
| 51 | |
| 52 | def load_control_net(v2, unet, model): |
| 53 | device = unet.device |
| 54 | |
| 55 | # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む |
| 56 | # state dictを読み込む |
| 57 | logger.info(f"ControlNet: loading control SD model : {model}") |
| 58 | |
| 59 | if model_util.is_safetensors(model): |
| 60 | ctrl_sd_sd = load_file(model) |
| 61 | else: |
| 62 | ctrl_sd_sd = torch.load(model, map_location="cpu") |
| 63 | ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) |
| 64 | |
| 65 | # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む |
| 66 | is_difference = "difference" in ctrl_sd_sd |
| 67 | logger.info(f"ControlNet: loading difference: {is_difference}") |
| 68 | |
| 69 | # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく |
| 70 | # またTransfer Controlの元weightとなる |
| 71 | ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) |
| 72 | |
| 73 | # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける |
| 74 | for key in list(ctrl_unet_sd_sd.keys()): |
| 75 | ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone() |
| 76 | |
| 77 | zero_conv_sd = {} |
| 78 | for key in list(ctrl_sd_sd.keys()): |
| 79 | if key.startswith("control_"): |
| 80 | unet_key = "model.diffusion_" + key[len("control_") :] |
| 81 | if unet_key not in ctrl_unet_sd_sd: # zero conv |
| 82 | zero_conv_sd[key] = ctrl_sd_sd[key] |
| 83 | continue |
| 84 | if is_difference: # Transfer Control |
| 85 | ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) |
| 86 | else: |
| 87 | ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) |
| 88 | |
| 89 | unet_config = model_util.create_unet_diffusers_config(v2) |
| 90 | ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict |
| 91 | |
| 92 | # ControlNetのU-Netを作成する |
| 93 | ctrl_unet = UNet2DConditionModel(**unet_config) |
| 94 | info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) |
| 95 | logger.info(f"ControlNet: loading Control U-Net: {info}") |
| 96 | |
| 97 | # U-Net以外のControlNetを作成する |
| 98 | # TODO support middle only |
| 99 | ctrl_net = ControlNet() |
| 100 | info = ctrl_net.load_state_dict(zero_conv_sd) |
| 101 | logger.info("ControlNet: loading ControlNet: {info}") |
| 102 | |
| 103 | ctrl_unet.to(unet.device, dtype=unet.dtype) |
| 104 | ctrl_net.to(unet.device, dtype=unet.dtype) |
| 105 | return ctrl_unet, ctrl_net |
| 106 | |
| 107 | |
| 108 | def load_preprocess(prep_type: str): |
nothing calls this directly
no test coverage detected