(unet_state_dict)
| 92 | |
| 93 | |
def convert_unet_state_dict(unet_state_dict):
    """Return a new dict holding the values of ``unet_state_dict`` under renamed keys.

    Renaming is done through a ``{original_key: new_key}`` table that is
    refined by three passes. The passes MUST run in this order, because the
    second and third do substring replacement on whatever the earlier passes
    produced:

    1. Whole-key renames from the module-level ``unet_conversion_map``
       (pairs of ``(sd_name, hf_name)``; key ``hf_name`` becomes ``sd_name``).
    2. For entries whose *original* key contains ``"resnets"``, every
       ``hf_part`` substring is replaced by ``sd_part`` for each pair in
       ``unet_conversion_map_resnet``, in list order.
    3. For every entry, the same substring replacement using
       ``unet_conversion_map_layer``, in list order.

    Values are not copied; the returned dict references the same objects.
    If two original keys end up with the same new key, the one that comes
    later in the table wins.

    Raises:
        KeyError: if an ``hf_name`` listed in ``unet_conversion_map`` is not
            a key of ``unet_state_dict`` (pass 1 adds it to the table and the
            final lookup then fails).
    """

    def substitute(name, pairs):
        # Apply each (replacement, target) pair in sequence; order is significant
        # because an earlier replacement can create or destroy a later match.
        for sd_part, hf_part in pairs:
            name = name.replace(hf_part, sd_part)
        return name

    original_keys = list(unet_state_dict.keys())
    renamed = dict(zip(original_keys, original_keys))

    # Pass 1: whole-key renames, applied in map order (may append new entries).
    renamed.update((hf_name, sd_name) for sd_name, hf_name in unet_conversion_map)

    # Pass 2: resnet-specific substrings, selected by the ORIGINAL key.
    for original in renamed:
        if "resnets" not in original:
            continue
        renamed[original] = substitute(renamed[original], unet_conversion_map_resnet)

    # Pass 3: layer-level substrings, applied to every entry.
    for original in renamed:
        renamed[original] = substitute(renamed[original], unet_conversion_map_layer)

    return {new_key: unet_state_dict[old_key] for old_key, new_key in renamed.items()}
| 112 | |
| 113 | |
| 114 | # ================# |
no outgoing calls
no test coverage detected