Updates paths inside resnets to the new naming scheme (local renaming)
(old_list, n_shave_prefix_segments=0)
| 48 | |
| 49 | |
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)

    Every entry of `old_list` is rewritten by applying a fixed, ordered series of
    substring substitutions, and the result is then passed through `shave_segments`
    together with `n_shave_prefix_segments`.

    Returns a list with one `{"old": <original entry>, "new": <rewritten entry>}` dict
    per input entry, in input order.
    """
    # (old substring, replacement) pairs, applied one after another in this order.
    substitutions = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    def rename(path):
        renamed = path
        for old_sub, new_sub in substitutions:
            renamed = renamed.replace(old_sub, new_sub)
        # Prefix handling is delegated entirely to shave_segments.
        return shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)

    return [{"old": path, "new": rename(path)} for path in old_list]
| 70 | |
| 71 | |
| 72 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): |
no test coverage detected