(unet_checkpoint_path, verify_param_count=True, sample_size=None)
| 716 | |
| 717 | |
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
    """Convert an original super-resolution UNet checkpoint into a diffusers model.

    Reads ``config.yml`` and ``pytorch_model.bin`` from ``unet_checkpoint_path``,
    derives the ``UNet2DConditionModel`` configuration, converts the weights, and
    returns the model with the converted weights loaded.

    Args:
        unet_checkpoint_path: Directory holding ``config.yml`` and
            ``pytorch_model.bin``.
        verify_param_count: When truthy, run the module-level
            ``verify_param_count`` helper on the derived config before
            converting the weights.
        sample_size: Value to use for ``sample_size``. When ``None`` the
            config's ``image_size`` is used instead.

    Returns:
        A ``UNet2DConditionModel`` with the converted state dict loaded.

    Raises:
        AssertionError: If the converted keys and the keys the model expects
            do not match exactly.
    """
    orig_path = unet_checkpoint_path

    # yaml.safe_load parses the text or stream it is handed; passing the path
    # string itself would just yield that string, so the file must be opened.
    with open(os.path.join(orig_path, "config.yml"), encoding="utf-8") as config_file:
        original_unet_config = yaml.safe_load(config_file)
    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)

    # "channel_mult" is a comma-separated string; the last multiplier scales
    # "model_channels" to give the time embedding width.
    last_channel_mult = int(original_unet_config["channel_mult"].split(",")[-1])
    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * last_channel_mult

    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
    unet_diffusers_config["class_embed_type"] = "timestep"
    unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071

    # Use the config's flag only when it is present and an int (bool counts as
    # an int); otherwise fall back to cross-attention only.
    if "disable_self_attentions" in original_unet_config and isinstance(
        original_unet_config["disable_self_attentions"], int
    ):
        unet_diffusers_config["only_cross_attention"] = bool(original_unet_config["disable_self_attentions"])
    else:
        unet_diffusers_config["only_cross_attention"] = True

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at checkpoints from a trusted source.
    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_param_count:
        # check that architecture matches - is a bit slow
        # The boolean parameter shadows the helper of the same name, so calling
        # the bare name would call a bool. Resolve the helper from module
        # globals instead.
        # NOTE(review): assumes a module-level function `verify_param_count`
        # exists in this file - it is not visible in this chunk; confirm.
        param_count_check = globals()["verify_param_count"]
        param_count_check(orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    converted_keys = converted_unet_checkpoint.keys()

    model = UNet2DConditionModel(**unet_diffusers_config)
    expected_weights = model.state_dict().keys()

    # Compare both directions so a mismatch names exactly which keys are off.
    diff_c_e = set(converted_keys) - set(expected_weights)
    diff_e_c = set(expected_weights) - set(converted_keys)

    assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
    assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"

    model.load_state_dict(converted_unet_checkpoint)

    return model
no test coverage detected
searching dependent graphs…