(state_dict_keys, key_prefix)
| 13 | |
| 14 | |
def calculate_layers(state_dict_keys, key_prefix):
    """Count the distinct integer indices found in keys containing ``key_prefix``.

    Every key that contains ``key_prefix`` as a substring is split on ``"."``
    and its third component (position 2) is parsed as an integer; the number
    of distinct integers is printed and returned.

    Args:
        state_dict_keys: Iterable of string keys (presumably checkpoint
            parameter names such as ``"<root>.<group>.<index>.<param>"`` --
            TODO confirm against the caller).
        key_prefix: Substring a key must contain to be counted. Despite the
            name it is matched anywhere in the key, not only at the start.

    Returns:
        The number of distinct indices seen (0 when no key matches).

    Raises:
        IndexError: A matching key has fewer than three dot-separated parts.
        ValueError: The third part of a matching key is not an integer.
    """
    # A set collapses the many parameter keys that share one index.
    layer_indices = {
        int(name.split(".")[2])
        for name in state_dict_keys
        if key_prefix in name
    }
    num_layers = len(layer_indices)
    print(f"{key_prefix}: {num_layers}")
    return num_layers
| 22 | |
| 23 | |
| 24 | # similar to SD3 but only for the last norm layer |
no test coverage detected
searching dependent graphs…