Convert controlnet weights. Args: transformer_type: The type of transformer/controlnet; control_state_dict: State dict containing controlnet-specific weights; base_state_dict: State dict containing base transformer weights (for shared modules); weights_only: Whether to use weights_only loading.
(
transformer_type: str,
control_state_dict: Dict[str, Any],
base_state_dict: Dict[str, Any],
weights_only: bool = True,
)
| 640 | |
| 641 | |
| 642 | def convert_controlnet( |
| 643 | transformer_type: str, |
| 644 | control_state_dict: Dict[str, Any], |
| 645 | base_state_dict: Dict[str, Any], |
| 646 | weights_only: bool = True, |
| 647 | ): |
| 648 | """ |
| 649 | Convert controlnet weights. |
| 650 | |
| 651 | Args: |
| 652 | transformer_type: The type of transformer/controlnet |
| 653 | control_state_dict: State dict containing controlnet-specific weights |
| 654 | base_state_dict: State dict containing base transformer weights (for shared modules) |
| 655 | weights_only: Whether to use weights_only loading |
| 656 | """ |
| 657 | if transformer_type not in CONTROLNET_CONFIGS: |
| 658 | raise AssertionError(f"{transformer_type} does not define a ControlNet config") |
| 659 | |
| 660 | PREFIX_KEY = "net." |
| 661 | |
| 662 | # Process control-specific keys |
| 663 | for key in list(control_state_dict.keys()): |
| 664 | new_key = key[:] |
| 665 | if new_key.startswith(PREFIX_KEY): |
| 666 | new_key = new_key.removeprefix(PREFIX_KEY) |
| 667 | for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items(): |
| 668 | new_key = new_key.replace(replace_key, rename_key) |
| 669 | update_state_dict_(control_state_dict, key, new_key) |
| 670 | |
| 671 | for key in list(control_state_dict.keys()): |
| 672 | for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items(): |
| 673 | if special_key not in key: |
| 674 | continue |
| 675 | handler_fn_inplace(key, control_state_dict) |
| 676 | |
| 677 | # Copy shared weights from base transformer to controlnet |
| 678 | # These are the duplicated modules: patch_embed_base, time_embed, learnable_pos_embed, img_context_proj, crossattn_proj |
| 679 | shared_module_mappings = { |
| 680 | # transformer key prefix -> controlnet key prefix |
| 681 | "patch_embed.": "patch_embed_base.", |
| 682 | "time_embed.": "time_embed.", |
| 683 | "learnable_pos_embed.": "learnable_pos_embed.", |
| 684 | "img_context_proj.": "img_context_proj.", |
| 685 | "crossattn_proj.": "crossattn_proj.", |
| 686 | } |
| 687 | |
| 688 | for key in list(base_state_dict.keys()): |
| 689 | for transformer_prefix, controlnet_prefix in shared_module_mappings.items(): |
| 690 | if key.startswith(transformer_prefix): |
| 691 | controlnet_key = controlnet_prefix + key[len(transformer_prefix) :] |
| 692 | control_state_dict[controlnet_key] = base_state_dict[key].clone() |
| 693 | print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True) |
| 694 | break |
| 695 | |
| 696 | cfg = CONTROLNET_CONFIGS[transformer_type] |
| 697 | controlnet = CosmosControlNetModel(**cfg) |
| 698 | |
| 699 | expected_keys = set(controlnet.state_dict().keys()) |
no test coverage detected
searching dependent graphs…