import os

from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
# Local Llama 3.2 1B Instruct model served through llama.cpp.
llm = LlamaCPP(
    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    temperature=0.1,
    max_new_tokens=256,
    context_window=16384,
)

# Local sentence-transformers model used to embed documents and queries.
embedding_model = HuggingFaceEmbedding(
    model_name="models/all-MiniLM-L6-v2"
)

# Register the LLM and embedding model as the llama_index defaults.
Settings.llm = llm
Settings.embed_model = embedding_model
def check_if_exists():
    """Return True if both the vector index and the BM25 retriever are already persisted."""
    index = os.path.exists("models/precomputed_index")
    bm25 = os.path.exists("models/bm25_retriever")
    return index and bm25
def precompute_index(data_folder="data"):
    """Build the vector index and BM25 retriever from data_folder and persist both."""
    documents = SimpleDirectoryReader(data_folder).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir="models/precomputed_index")
    # Build BM25 over the same chunked nodes stored in the index's docstore,
    # so both retrievers work on identical text units.
    nodes = list(index.docstore.docs.values())
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=nodes,
        similarity_top_k=5,
    )
    bm25_retriever.persist("models/bm25_retriever")
def is_harmful(query):
    """Crude keyword-based safety filter for incoming queries."""
    harmful_keywords = ["bomb", "kill", "weapon", "suicide", "terror", "attack"]
    return any(keyword in query.lower() for keyword in harmful_keywords)
def answer_question(query):
    """Answer a query with hybrid (vector + BM25) retrieval, LLM reranking, and the local LLM."""
    # Reject flagged queries before loading any retrievers.
    if is_harmful(query):
        return "This query has been flagged as unsafe."

    print("loading bm25 retriever")
    bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")

    print("loading saved vector index")
    storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
    index = load_index_from_storage(storage_context)

    # Fuse dense (vector) and sparse (BM25) retrieval results.
    # num_queries=1 disables LLM query generation; only the original query is used.
    retriever = QueryFusionRetriever(
        [
            index.as_retriever(similarity_top_k=5),
            bm25_retriever,
        ],
        llm=llm,
        num_queries=1,
        similarity_top_k=5,
    )

    # Rerank the fused candidates with the LLM configured in Settings.
    reranker = LLMRerank(
        choice_batch_size=5,
        top_n=5,
    )

    query_engine = RetrieverQueryEngine(
        retriever=retriever,
        node_postprocessors=[reranker],
    )

    response = query_engine.query(query)
    return str(response)
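
# Example entry point (a minimal sketch, not part of the original Space): build the
# persisted index and BM25 retriever on first run, then answer a question.
# The sample query string below is purely illustrative.
if __name__ == "__main__":
    if not check_if_exists():
        precompute_index(data_folder="data")
    print(answer_question("What topics are covered in the indexed documents?"))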