MCPcopy Index your code
hub / github.com/huggingface/diffusers / get_super_res_unet

Function get_super_res_unet

scripts/convert_if.py:718–775  ·  view source on GitHub ↗
(unet_checkpoint_path, verify_param_count=True, sample_size=None)

Source from the content-addressed store, hash-verified

716
717
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
    """Build a diffusers ``UNet2DConditionModel`` from an original super-resolution UNet checkpoint.

    Args:
        unet_checkpoint_path: Directory that contains the original ``config.yml``
            and ``pytorch_model.bin`` files.
        verify_param_count: When truthy, run the module-level
            ``verify_param_count`` helper on the derived config before
            converting the weights (the original comment notes this is slow).
        sample_size: Optional override for the ``sample_size`` entry of the
            diffusers config. When ``None``, ``image_size`` from the original
            config is used.

    Returns:
        The ``UNet2DConditionModel`` instance with the converted weights loaded.

    Raises:
        AssertionError: If the converted checkpoint keys and the keys of the
            freshly built model's ``state_dict`` do not match exactly.
    """
    orig_path = unet_checkpoint_path

    # yaml.safe_load parses the text or stream it is handed; giving it the
    # path string would parse the path itself rather than the file contents,
    # so the config file has to be opened first.
    with open(os.path.join(orig_path, "config.yml"), encoding="utf-8") as config_file:
        original_unet_config = yaml.safe_load(config_file)
    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)

    # "channel_mult" is a comma-separated string; the last multiplier scales
    # "model_channels" to give the time embedding width.
    last_channel_mult = int(original_unet_config["channel_mult"].split(",")[-1])
    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * last_channel_mult

    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
    unet_diffusers_config["class_embed_type"] = "timestep"
    unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071

    # Only an integer "disable_self_attentions" entry is honoured; a missing
    # key or any other type falls back to True.
    if "disable_self_attentions" in original_unet_config and isinstance(
        original_unet_config["disable_self_attentions"], int
    ):
        unet_diffusers_config["only_cross_attention"] = bool(original_unet_config["disable_self_attentions"])
    else:
        unet_diffusers_config["only_cross_attention"] = True

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only run this on checkpoints from a trusted source.
    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_param_count:
        # check that architecture matches - is a bit slow
        # The boolean parameter shadows the module-level helper of the same
        # name, so calling `verify_param_count(...)` directly would try to
        # call a bool. Resolve the helper from the module namespace instead.
        param_count_checker = globals()["verify_param_count"]
        param_count_checker(orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    converted_keys = converted_unet_checkpoint.keys()

    model = UNet2DConditionModel(**unet_diffusers_config)
    expected_weights = model.state_dict().keys()

    diff_c_e = set(converted_keys) - set(expected_weights)
    diff_e_c = set(expected_weights) - set(converted_keys)

    assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
    assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"

    model.load_state_dict(converted_unet_checkpoint)

    return model

Callers 1

Calls 8

load_state_dictMethod · 0.95
verify_param_countFunction · 0.85
splitMethod · 0.80
loadMethod · 0.45
state_dictMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…