(state_dict)
| 245 | |
| 246 | |
def get_attn2_layers(state_dict):
    """Return the sorted, unique layer indices that have ``attn2`` keys.

    A key is considered when it contains the substring ``"attn2."``. The layer
    index is taken from the second dot-separated component of that key and
    parsed as an integer.

    Args:
        state_dict: Mapping whose keys are dot-separated parameter names.

    Returns:
        A tuple of the distinct layer indices in ascending order; an empty
        tuple when no key contains ``"attn2."``.

    Raises:
        ValueError: If a matching key's second component is not an integer.
    """
    # Collect straight into a set: several keys may share one layer index,
    # so there is no point accumulating duplicates only to drop them later.
    layer_indices = {
        int(key.split(".")[1]) for key in state_dict if "attn2." in key
    }
    return tuple(sorted(layer_indices))
| 255 | |
| 256 | |
| 257 | def get_pos_embed_max_size(state_dict): |