(model, layer_ids)
| 282 | |
| 283 | |
def _patch_model(model, layer_ids):
    """Wrap the selected layers of *model* in ``_LayerHook`` objects.

    The call is a no-op when *model* already has a ``_hidden_states``
    attribute, so a model is only patched once; a later call with different
    ``layer_ids`` leaves the existing patch untouched.

    Otherwise a list with one ``None`` slot per entry of ``layer_ids`` is
    stored on the model as ``model._hidden_states``, and each selected entry
    of the container returned by ``_get_layers(model)`` is replaced in place
    by ``_LayerHook(original_layer, slot_index, model._hidden_states)``.

    Args:
        model: Object to patch. It gains a ``_hidden_states`` attribute.
        layer_ids: Keys/indices into the container returned by
            ``_get_layers(model)``. Position ``i`` in this sequence is the
            slot index handed to the hook for that layer.
    """
    # Already patched: the attribute doubles as the "patched" marker.
    if hasattr(model, "_hidden_states"):
        return

    # One slot per hooked layer; the same list object is shared with every
    # hook. NOTE(review): presumably each _LayerHook writes its layer's
    # output into its own slot -- confirm against _LayerHook.
    model._hidden_states = [None] * len(layer_ids)
    shared_slots = model._hidden_states

    model_layers = _get_layers(model)
    for slot, layer_id in enumerate(layer_ids):
        wrapped = model_layers[layer_id]
        model_layers[layer_id] = _LayerHook(wrapped, slot, shared_slots)
| 291 | |
| 292 | |
| 293 | class _GDNStateCapture: |
no test coverage detected