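"""Retrieval-augmented question answering over local documents.

Builds and persists a vector index and a BM25 retriever, then answers
queries with hybrid retrieval, LLM reranking, and a local llama.cpp model.
"""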
import os
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.postprocessor import LLMRerank
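# Local Llama 3.2 1B Instruct model (4-bit GGUF) run via llama.cpp.
# A low temperature keeps answers close to the retrieved context.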
llm = LlamaCPP(
    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    temperature=0.1,
    max_new_tokens=256,
    context_window=16384,
)
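# Local all-MiniLM-L6-v2 sentence-transformers model for embedding text chunks.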
embedding_model = HuggingFaceEmbedding(
    model_name="models/all-MiniLM-L6-v2"
)
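# Register the LLM and embedding model as llama_index-wide defaults.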
Settings.llm = llm
Settings.embed_model = embedding_model
def check_if_exists():
    """Return True if both the vector index and the BM25 retriever have been persisted."""
    index = os.path.exists("models/precomputed_index")
    bm25 = os.path.exists("models/bm25_retriever")
    return index and bm25
def precompute_index(data_folder="data"):
    """Build and persist the vector index and BM25 retriever over data_folder."""
    documents = SimpleDirectoryReader(data_folder).load_data()
    # Chunk documents into nodes once so the vector index and the BM25
    # retriever are built over the same units of text, rather than BM25
    # scoring whole documents while the vector index scores chunks.
    nodes = Settings.node_parser.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes)
    index.storage_context.persist(persist_dir="models/precomputed_index")
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=nodes,
        similarity_top_k=5,
    )
    bm25_retriever.persist("models/bm25_retriever")
def is_harmful(query):
    """Naive keyword-based safety check on the raw query text."""
    harmful_keywords = ["bomb", "kill", "weapon", "suicide", "terror", "attack"]
    return any(keyword in query.lower() for keyword in harmful_keywords)
def answer_question(query):
    """Answer a query with hybrid (vector + BM25) retrieval, LLM reranking, and the local LLM."""
    # Reject unsafe queries up front, before loading any retrievers.
    if is_harmful(query):
        return "This query has been flagged as unsafe."
    print("loading bm25 retriever")
    bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
    print("loading saved vector index")
    storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
    index = load_index_from_storage(storage_context)
    # Fuse dense (vector) and sparse (BM25) retrieval results for the query.
    retriever = QueryFusionRetriever(
        [
            index.as_retriever(similarity_top_k=5),
            bm25_retriever,
        ],
        llm=llm,
        num_queries=1,  # no LLM query generation; use the original query only
        similarity_top_k=5,
    )
    # Rerank the fused candidates with the LLM before answer synthesis.
    reranker = LLMRerank(
        choice_batch_size=5,
        top_n=5,
    )
    query_engine = RetrieverQueryEngine(
        retriever=retriever,
        node_postprocessors=[reranker],
    )
    response = query_engine.query(query)
    return str(response)
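# Minimal usage sketch, assuming documents live under the default "data"
# folder: build the indexes on first run, then answer a question. The sample
# query below is illustrative only.
if __name__ == "__main__":
    if not check_if_exists():
        precompute_index()
    print(answer_question("What topics are covered in the documents?"))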