Estimate context size based on the model name. EXPERIMENTAL for memory compression
(self, model_name: str)
| 44 | self.download_model() |
| 45 | |
| 46 | def get_ideal_ctx(self, model_name: str) -> int | None: |
| 47 | """ |
| 48 | Estimate context size based on the model name. |
| 49 | EXPERIMENTAL for memory compression |
| 50 | """ |
| 51 | import re |
| 52 | import math |
| 53 | |
| 54 | def extract_number_before_b(sentence: str) -> int: |
| 55 | match = re.search(r'(\d+)b', sentence, re.IGNORECASE) |
| 56 | return int(match.group(1)) if match else None |
| 57 | |
| 58 | model_size = extract_number_before_b(model_name) |
| 59 | if not model_size: |
| 60 | return None |
| 61 | base_size = 7 # Base model size in billions |
| 62 | base_context = 4096 # Base context size in tokens |
| 63 | scaling_factor = 1.5 # Approximate scaling factor for context size growth |
| 64 | context_size = int(base_context * (model_size / base_size) ** scaling_factor) |
| 65 | context_size = 2 ** round(math.log2(context_size)) |
| 66 | self.logger.info(f"Estimated context size for {model_name}: {context_size} tokens.") |
| 67 | return context_size |
| 68 | |
| 69 | def download_model(self): |
| 70 | """Download the model if not already downloaded.""" |
no test coverage detected