Get the inner model (model.model or model.transformer).
(model)
| 73 | |
| 74 | |
def get_inner_model(model):
    """Return the backbone module wrapped by *model*.

    The ``model`` attribute is tried first, then ``transformer``; the first
    one holding an ``nn.Module`` is selected. If the selected module itself
    exposes an ``nn.Module`` under ``model`` (one extra level of wrapping,
    e.g. ``language_model.model``), that nested module is returned instead.

    Args:
        model: Any object; only its ``model`` / ``transformer`` attributes
            are inspected.

    Returns:
        The selected (possibly one-level-unwrapped) ``nn.Module``.

    Raises:
        ValueError: If neither attribute holds an ``nn.Module``.
    """
    # Lazy generator: ``transformer`` is only read if ``model`` did not match.
    candidates = (getattr(model, name, None) for name in ("model", "transformer"))
    found = next((c for c in candidates if isinstance(c, nn.Module)), None)
    if found is None:
        raise ValueError("Model must have a 'model' or 'transformer' attribute")

    nested = getattr(found, "model", None)
    return nested if isinstance(nested, nn.Module) else found
| 86 | |
| 87 | |
| 88 | def get_layers(inner_model): |
no outgoing calls
no test coverage detected