MCPcopy Index your code
hub / github.com/huggingface/diffusers / convert_controlnet

Function convert_controlnet

scripts/convert_cosmos_to_diffusers.py:642–715  ·  view source on GitHub ↗

Convert controlnet weights. Args: transformer_type: The type of transformer/controlnet; control_state_dict: State dict containing controlnet-specific weights; base_state_dict: State dict containing base transformer weights (for shared modules); weights_only: Whether to use weights_only loading.

(
    transformer_type: str,
    control_state_dict: Dict[str, Any],
    base_state_dict: Dict[str, Any],
    weights_only: bool = True,
)

Source from the content-addressed store, hash-verified

640
641
642def convert_controlnet(
643 transformer_type: str,
644 control_state_dict: Dict[str, Any],
645 base_state_dict: Dict[str, Any],
646 weights_only: bool = True,
647):
648 """
649 Convert controlnet weights.
650
651 Args:
652 transformer_type: The type of transformer/controlnet
653 control_state_dict: State dict containing controlnet-specific weights
654 base_state_dict: State dict containing base transformer weights (for shared modules)
655 weights_only: Whether to use weights_only loading
656 """
657 if transformer_type not in CONTROLNET_CONFIGS:
658 raise AssertionError(f"{transformer_type} does not define a ControlNet config")
659
660 PREFIX_KEY = "net."
661
662 # Process control-specific keys
663 for key in list(control_state_dict.keys()):
664 new_key = key[:]
665 if new_key.startswith(PREFIX_KEY):
666 new_key = new_key.removeprefix(PREFIX_KEY)
667 for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items():
668 new_key = new_key.replace(replace_key, rename_key)
669 update_state_dict_(control_state_dict, key, new_key)
670
671 for key in list(control_state_dict.keys()):
672 for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items():
673 if special_key not in key:
674 continue
675 handler_fn_inplace(key, control_state_dict)
676
677 # Copy shared weights from base transformer to controlnet
678 # These are the duplicated modules: patch_embed_base, time_embed, learnable_pos_embed, img_context_proj, crossattn_proj
679 shared_module_mappings = {
680 # transformer key prefix -> controlnet key prefix
681 "patch_embed.": "patch_embed_base.",
682 "time_embed.": "time_embed.",
683 "learnable_pos_embed.": "learnable_pos_embed.",
684 "img_context_proj.": "img_context_proj.",
685 "crossattn_proj.": "crossattn_proj.",
686 }
687
688 for key in list(base_state_dict.keys()):
689 for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
690 if key.startswith(transformer_prefix):
691 controlnet_key = controlnet_prefix + key[len(transformer_prefix) :]
692 control_state_dict[controlnet_key] = base_state_dict[key].clone()
693 print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True)
694 break
695
696 cfg = CONTROLNET_CONFIGS[transformer_type]
697 controlnet = CosmosControlNetModel(**cfg)
698
699 expected_keys = set(controlnet.state_dict().keys())

Callers 1

Calls 4

update_state_dict_Function · 0.70
state_dictMethod · 0.45
load_state_dictMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…