File size: 2,616 Bytes
b0fc7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import re
import os
from llama_cpp import Llama, LlamaGrammar
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.postprocessor import LLMRerank

# Local GGUF model served through llama.cpp; low temperature keeps answers
# close to the retrieved context rather than creative.
llm = LlamaCPP(
    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    temperature=0.1,
    max_new_tokens=256,
    context_window=16384
)
# Local sentence-transformer checkpoint used for dense retrieval embeddings.
embedding_model = HuggingFaceEmbedding(
    model_name="models/all-MiniLM-L6-v2"
)
# Register both as llama-index global defaults so downstream components
# (index construction, query engine) pick them up implicitly.
Settings.llm = llm
Settings.embed_model = embedding_model


def check_if_exists():
    """Return True if both persisted retrieval artifacts are on disk.

    Checks for the vector index and the BM25 retriever directories that
    precompute_index() writes; both are required for answering queries.
    """
    has_index = os.path.exists("models/precomputed_index")
    has_bm25 = os.path.exists("models/bm25_retriever")
    # Return the boolean directly instead of the if/else True/False anti-idiom.
    return has_index and has_bm25


def precompute_index(data_folder='data'):
    """Build and persist the retrieval artifacts from *data_folder*.

    Writes a vector index to models/precomputed_index and a BM25
    retriever to models/bm25_retriever.
    """
    docs = SimpleDirectoryReader(data_folder).load_data()

    # Dense side: embed and persist the vector index.
    vector_index = VectorStoreIndex.from_documents(docs)
    vector_index.storage_context.persist(persist_dir='models/precomputed_index')

    # Sparse side: keyword retriever over the same corpus.
    # NOTE(review): whole documents are passed as `nodes` here — presumably
    # BM25Retriever accepts Document objects; confirm against the library API.
    sparse_retriever = BM25Retriever.from_defaults(nodes=docs, similarity_top_k=5)
    sparse_retriever.persist("models/bm25_retriever")

# Compiled once at import time. The leading \b requires each keyword to start
# at a word boundary, so "skill"/"overkill" no longer false-positive on "kill",
# while stem matches the old substring check caught ("terrorism", "attacks",
# "killing") are preserved because there is no trailing boundary.
_HARMFUL_RE = re.compile(r"\b(bomb|kill|weapon|suicide|terror|attack)", re.IGNORECASE)


def is_harmful(query):
    """Return True if *query* contains a blocklisted word (case-insensitive).

    Keyword blocklist only — not a real safety classifier; intended as a
    cheap pre-filter before running the RAG pipeline.
    """
    return _HARMFUL_RE.search(query) is not None


def answer_question(query):
    """Answer *query* with the hybrid (vector + BM25) RAG pipeline.

    Returns a refusal string for blocklisted queries, otherwise the
    LLM-generated answer as a string. Expects the artifacts written by
    precompute_index() to exist on disk.
    """
    # Refuse up front — the original checked this only after loading both
    # retrievers and building the engine, wasting that work on flagged queries.
    if is_harmful(query):
        return "This query has been flagged as unsafe."

    print("loading bm25 retriever")
    bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
    print("loading saved vector index")
    storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
    index = load_index_from_storage(storage_context)

    # Fuse dense (vector) and sparse (BM25) candidate sets.
    retriever = QueryFusionRetriever(
        [
            index.as_retriever(similarity_top_k=5),
            bm25_retriever,
        ],
        llm=llm,
        num_queries=1,  # 1 = no LLM query expansion; just fuse the two retrievers
        similarity_top_k=5,
    )
    # Re-rank the fused candidates with the LLM before answer synthesis.
    reranker = LLMRerank(
        choice_batch_size=5,
        top_n=5,
    )
    query_engine = RetrieverQueryEngine(
        retriever=retriever,
        node_postprocessors=[reranker],
    )

    response = query_engine.query(query)
    return str(response)