CalculateVRAM calculates the VRAM usage for a given model and configuration
(modelID string, bpw float64, context int, kvCacheQuant KVCacheQuantisation, ollamaModelInfo *OllamaModelInfo)
| 486 | |
| 487 | // CalculateVRAM calculates the VRAM usage for a given model and configuration |
| 488 | func CalculateVRAM(modelID string, bpw float64, context int, kvCacheQuant KVCacheQuantisation, ollamaModelInfo *OllamaModelInfo) (float64, error) { |
| 489 | logging.DebugLogger.Println("Calculating VRAM usage...") |
| 490 | |
| 491 | var config ModelConfig |
| 492 | var err error |
| 493 | |
| 494 | if ollamaModelInfo != nil { |
| 495 | // Use Ollama model information |
| 496 | paramCount, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "parameter_count") |
| 497 | contextLength, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "context_length") |
| 498 | blockCount, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "block_count") |
| 499 | embeddingLength, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "embedding_length") |
| 500 | headCountKV, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "attention.head_count_kv") |
| 501 | headCount, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "attention.head_count") |
| 502 | feedForwardLength, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "feed_forward_length") |
| 503 | vocabSize, _ := extractModelInfo(ollamaModelInfo.ModelInfo, "vocab_size") |
| 504 | |
| 505 | config = ModelConfig{ |
| 506 | NumParams: paramCount / 1e9, // Convert to billions |
| 507 | MaxPositionEmbeddings: int(contextLength), |
| 508 | NumHiddenLayers: int(blockCount), |
| 509 | HiddenSize: int(embeddingLength), |
| 510 | NumKeyValueHeads: int(headCountKV), |
| 511 | NumAttentionHeads: int(headCount), |
| 512 | IntermediateSize: int(feedForwardLength), |
| 513 | VocabSize: int(vocabSize), |
| 514 | } |
| 515 | |
| 516 | // Estimate missing values |
| 517 | if config.HiddenSize == 0 { |
| 518 | config.HiddenSize = int(math.Sqrt(paramCount / 1000)) |
| 519 | } |
| 520 | if config.NumHiddenLayers == 0 { |
| 521 | config.NumHiddenLayers = int(math.Round(config.NumParams * 1e9 / (12 * float64(config.HiddenSize) * float64(config.HiddenSize)))) |
| 522 | } |
| 523 | if config.NumAttentionHeads == 0 { |
| 524 | config.NumAttentionHeads = config.HiddenSize / 64 // Assuming 64 dimension per head |
| 525 | } |
| 526 | if config.NumKeyValueHeads == 0 { |
| 527 | config.NumKeyValueHeads = config.NumAttentionHeads |
| 528 | } |
| 529 | if config.IntermediateSize == 0 { |
| 530 | config.IntermediateSize = 4 * config.HiddenSize |
| 531 | } |
| 532 | if config.VocabSize == 0 { |
| 533 | config.VocabSize = 32000 // A common default value |
| 534 | } |
| 535 | |
| 536 | // Parse BPW from quantisation level if not provided |
| 537 | if bpw == 0 { |
| 538 | bpw, err = ParseBPWOrQuant(ollamaModelInfo.Details.QuantizationLevel) |
| 539 | if err != nil { |
| 540 | return 0, fmt.Errorf("error parsing BPW from Ollama quantisation level: %v", err) |
| 541 | } |
| 542 | } |
| 543 | |
| 544 | logging.DebugLogger.Printf("Processed Ollama Model Config: %+v", config) |
| 545 | } else { |
no test coverage detected