Get the context length of a model from a Hugging Face model config.
Signature: `get_context_length(config)`
| 330 | |
| 331 | |
def get_context_length(config):
    """Get the context length of a model from a Hugging Face model config.

    Scans ``SEQUENCE_LENGTH_KEYS`` in order and returns the first attribute of
    ``config`` that is not ``None``. That value is multiplied by the RoPE
    scaling factor and truncated to ``int``.

    Args:
        config: A model config object. Attributes are read with ``getattr``,
            so missing attributes are tolerated. ``config.rope_scaling``, if
            set and truthy, is expected to be dict-like.

    Returns:
        The scaled context length, or ``2048`` if none of the keys in
        ``SEQUENCE_LENGTH_KEYS`` is present on ``config``.
    """
    rope_scaling = getattr(config, "rope_scaling", None)

    # Default to no scaling. A rope_scaling dict may omit "factor", or set it
    # to None; in both cases fall back to 1 rather than raising
    # KeyError/TypeError.
    rope_scaling_factor = 1
    if rope_scaling:
        factor = rope_scaling.get("factor")
        if factor is not None:
            rope_scaling_factor = factor

    # NOTE(review): SEQUENCE_LENGTH_KEYS is a module-level constant defined
    # outside this view; presumably it lists config attribute names such as
    # max_position_embeddings, in priority order. Confirm.
    for key in SEQUENCE_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048
| 345 | |
| 346 | |
| 347 | def str_to_torch_dtype(dtype: str): |
no outgoing calls
no test coverage detected
searching dependent graphs…