github.com/e-p-armstrong/augmentoolkit

Function rag_server

generation/utilities/rag_server/rag_server.py:629–1009
(
    prompt_path,
    template_path,
    gguf_model_path,
    context_length,
    documents_dir: str,
    questions_jsonl_path: str,
    question_chunk_size: int,
    top_k: int,
    cache_dir: str,
    max_shrink_iterations,
    collection_name: str = "questions_collection",
    llama_path="./llama.cpp",  # customizable llama.cpp path
    port=8003,
    task_id=None,
    **kwargs,
)
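The signature has ten required parameters and four optional ones (collection_name, llama_path, port, task_id), plus **kwargs. One way a caller could keep these together is a small config object. The sketch below is hypothetical: RagServerConfig and as_kwargs() are not part of augmentoolkit. The types given to unannotated parameters (context_length, max_shrink_iterations, llama_path, port, task_id) are assumptions, and the defaults are copied from the signature.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@dataclass
class RagServerConfig:
    """Hypothetical container mirroring rag_server's parameters."""

    # Required parameters (no defaults in the signature).
    prompt_path: str
    template_path: str
    gguf_model_path: str
    context_length: int  # ASSUMPTION: unannotated in the source
    documents_dir: str
    questions_jsonl_path: str
    question_chunk_size: int
    top_k: int
    cache_dir: str
    max_shrink_iterations: int  # ASSUMPTION: unannotated in the source
    # Optional parameters, defaults copied from the signature.
    collection_name: str = "questions_collection"
    llama_path: str = "./llama.cpp"
    port: int = 8003
    task_id: Optional[str] = None  # ASSUMPTION: type not shown in the source
    # Anything else is forwarded through **kwargs.
    extra: Dict[str, Any] = field(default_factory=dict)

    def as_kwargs(self) -> Dict[str, Any]:
        """Flatten into keyword arguments, merging `extra` at the top level."""
        kwargs = asdict(self)
        extra = kwargs.pop("extra")
        kwargs.update(extra)
        return kwargs
```

Because rag_server is declared `async def`, the result would be awaited from inside a coroutine, e.g. `await rag_server(**cfg.as_kwargs())`.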

Source from the content-addressed store, hash-verified

```python
# os and count_tokens_specific_model are imported or defined at module level in rag_server.py.
async def rag_server(
    prompt_path,
    template_path,
    gguf_model_path,
    context_length,
    documents_dir: str,
    questions_jsonl_path: str,
    question_chunk_size: int,
    top_k: int,
    cache_dir: str,
    max_shrink_iterations,
    collection_name: str = "questions_collection",
    llama_path="./llama.cpp",  # customizable llama.cpp path
    port=8003,
    task_id=None,
    **kwargs,
):

    with open(prompt_path, "r") as f:
        prompt = f.read()

    with open(template_path, "r") as f:
        template = f.read()

    # Get parent directory of gguf model path
    model_dir = os.path.dirname(gguf_model_path)

    # Initialize model-specific token counting.
    # Assumes the tokenizer and related files are still in the same dir as the saved gguf model.
    try:
        count_tokens = count_tokens_specific_model(model_dir)
    except Exception as e:
        print(e)
        print(
            "\n\nYou probably deleted the tokenizer and other such files from the model directory "
            "after you got your quantized model. However, the model tokenizer is used to count tokens "
            "before sending the prompt off to llama.cpp, so you shouldn't have done that. You can "
            "delete model files (*.safetensors) to save space, but leave the tokenizer alone."
        )
        print(
            "To fix this, re-run the datagen pipeline that produced this model. "
            "It won't re-train, but it will re-download the model files, tokenizer included."
        )
        raise
```
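rag_server takes the tokenizer directory to be the parent of gguf_model_path, so if the tokenizer files were deleted after quantization, count_tokens_specific_model fails and the server prints a recovery hint and re-raises. A caller could check for this before launching. The sketch below is a hypothetical pre-flight check (find_tokenizer_files and require_tokenizer are not augmentoolkit functions). The list of file names is an assumption based on common Hugging Face tokenizer layouts, not what count_tokens_specific_model actually loads.

```python
import os
from typing import List

# ASSUMPTION: common Hugging Face tokenizer file names; the files that
# count_tokens_specific_model really needs are not shown in the excerpt.
TOKENIZER_FILES = ("tokenizer.json", "tokenizer.model", "tokenizer_config.json")


def find_tokenizer_files(gguf_model_path: str) -> List[str]:
    """Return which known tokenizer files sit next to the .gguf file."""
    # Same derivation rag_server uses: the GGUF file's parent directory.
    model_dir = os.path.dirname(gguf_model_path)
    return [
        name
        for name in TOKENIZER_FILES
        if os.path.isfile(os.path.join(model_dir, name))
    ]


def require_tokenizer(gguf_model_path: str) -> str:
    """Hypothetical check: raise early if no tokenizer files are present.

    Returns the model directory on success.
    """
    model_dir = os.path.dirname(gguf_model_path)
    if not find_tokenizer_files(gguf_model_path):
        raise FileNotFoundError(
            f"No tokenizer files found in {model_dir or '.'}. Keep the tokenizer "
            "next to the .gguf file; only the *.safetensors weights are safe to delete."
        )
    return model_dir
```

Running require_tokenizer(gguf_model_path) before starting the server reports the missing tokenizer before any FastAPI or retrieval setup happens.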
```python
    # (rag_server body, continued; FastAPI is imported at module level)

    # First things first: ragify the docs

    app = FastAPI(
        title="Augmentoolkit Custom Model RAG-Enabled API Server", version="0.1.0"
    )

    chroma_client = None
    collection = None
    bm25_index = None
    bm25_corpus_data = []

    # BM25 persistence paths
    BM25_CACHE_DIR = os.path.join(cache_dir, "bm25_cache")
    BM25_INDEX_PATH = os.path.join(BM25_CACHE_DIR, "bm25_index.pkl")
    BM25_CORPUS_DATA_PATH = os.path.join(BM25_CACHE_DIR, "bm25_corpus_data.json")

    # ... remainder of rag_server (through line 1009) not shown
```
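The excerpt ends by laying out a BM25 cache under cache_dir: bm25_cache/bm25_index.pkl for the pickled index and bm25_corpus_data.json for the corpus records. It does not show which BM25 implementation the server uses or what a corpus record contains. The sketch below therefore swaps in a minimal pure-Python Okapi BM25 class (TinyBM25, hypothetical, using the non-negative IDF variant log(1 + (N - df + 0.5) / (df + 0.5))) and hypothetical save_bm25/load_bm25 helpers that reuse the same file layout. A None return from load_bm25 means a cache miss, so the index would have to be rebuilt.

```python
import json
import math
import os
import pickle
from collections import Counter
from typing import Any, List, Optional, Tuple


class TinyBM25:
    """Minimal Okapi BM25 over pre-tokenized documents (hypothetical stand-in)."""

    def __init__(self, corpus: List[List[str]], k1: float = 1.5, b: float = 0.75):
        self.k1, self.b = k1, b
        self.doc_freqs = [Counter(doc) for doc in corpus]  # term counts per doc
        self.doc_lens = [len(doc) for doc in corpus]
        self.avgdl = sum(self.doc_lens) / len(corpus) if corpus else 0.0
        df = Counter(term for doc in corpus for term in set(doc))  # document frequency
        n = len(corpus)
        self.idf = {t: math.log(1 + (n - f + 0.5) / (f + 0.5)) for t, f in df.items()}

    def scores(self, query: List[str]) -> List[float]:
        """BM25 score of every document for the given query terms."""
        out = []
        for freqs, dl in zip(self.doc_freqs, self.doc_lens):
            s = 0.0
            for term in query:
                tf = freqs.get(term, 0)
                if tf:  # tf > 0 implies a non-empty corpus, so avgdl > 0
                    norm = self.k1 * (1 - self.b + self.b * dl / self.avgdl)
                    s += self.idf[term] * tf * (self.k1 + 1) / (tf + norm)
            out.append(s)
        return out


def cache_paths(cache_dir: str) -> Tuple[str, str]:
    """Same layout as rag_server: <cache_dir>/bm25_cache/{bm25_index.pkl, bm25_corpus_data.json}."""
    bm25_dir = os.path.join(cache_dir, "bm25_cache")
    return (
        os.path.join(bm25_dir, "bm25_index.pkl"),
        os.path.join(bm25_dir, "bm25_corpus_data.json"),
    )


def save_bm25(cache_dir: str, index: TinyBM25, corpus_data: List[Any]) -> None:
    """Pickle the index and write the corpus records as JSON."""
    index_path, data_path = cache_paths(cache_dir)
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    with open(index_path, "wb") as f:
        pickle.dump(index, f)
    with open(data_path, "w", encoding="utf-8") as f:
        json.dump(corpus_data, f)


def load_bm25(cache_dir: str) -> Optional[Tuple[TinyBM25, List[Any]]]:
    """Return (index, corpus_data), or None if either cache file is missing."""
    index_path, data_path = cache_paths(cache_dir)
    if not (os.path.exists(index_path) and os.path.exists(data_path)):
        return None  # cache miss: caller rebuilds the index
    with open(index_path, "rb") as f:
        index = pickle.load(f)  # only unpickle files you created yourself
    with open(data_path, "r", encoding="utf-8") as f:
        corpus_data = json.load(f)
    return index, corpus_data
```

Keeping the corpus records in JSON leaves them readable and editable, while the pickled index restores precomputed term statistics without re-tokenizing the corpus.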

Callers (2)

factual_datagen_full (function) · 0.90
discord_inference (function) · 0.90

Calls (7)

EngineWrapper (class) · 0.90
get_stop_tokens (function) · 0.90
get_assistant_prefix (function) · 0.90
set_progress (function) · 0.90
vectorize_documents (function) · 0.85
run (method) · 0.45

Tested by

No test coverage detected.