Updates paths inside resnets to the new naming scheme (local renaming)
(old_list, n_shave_prefix_segments=0)
| 110 | |
| 111 | |
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Every key in ``old_list`` has its ResNet sub-module names rewritten by
    plain substring substitution:

        in_layers.0     -> norm1
        in_layers.2     -> conv1
        out_layers.0    -> norm2
        out_layers.3    -> conv2
        emb_layers.1    -> time_emb_proj
        skip_connection -> conv_shortcut

    The renamed key is then passed through ``shave_segments`` with
    ``n_shave_prefix_segments``.

    Keys whose *original* name contains ``"temopral_conv"`` are omitted from
    the result.

    Args:
        old_list: iterable of original parameter-path strings.
        n_shave_prefix_segments: forwarded unchanged to ``shave_segments``.

    Returns:
        A list of ``{"old": original_key, "new": renamed_key}`` dicts, in input
        order.
    """
    # Substitutions run in this order, each on the output of the previous one.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for source_key in old_list:
        target_key = source_key
        for before, after in renames:
            target_key = target_key.replace(before, after)
        target_key = shave_segments(target_key, n_shave_prefix_segments=n_shave_prefix_segments)

        # NOTE(review): "temopral_conv" (sic) is kept verbatim. It presumably
        # has to match the key spelling used in the source checkpoint, so do
        # not "correct" it without checking those keys. Such entries are
        # excluded here; confirm whether another helper maps them.
        if "temopral_conv" in source_key:
            continue
        mapping.append({"old": source_key, "new": target_key})

    return mapping
| 133 | |
| 134 | |
| 135 | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): |
no test coverage detected
searching dependent graphs…