revamp rag

Browse files
- rag_app/chat_utils.py  +10 -19
- rag_app/rag_2.py       +79 -0
- requirements.txt       +19 -1
rag_app/chat_utils.py
CHANGED
@@ -3,8 +3,8 @@ import mesop as me
 from dataclasses import dataclass, field
 from typing import Callable, Generator, Literal
 import time
-from rag_app.rag import extract_final_answer, answer_question
-
+# from rag_app.rag import extract_final_answer, answer_question
+from rag_app.rag_2 import check_if_exists, precompute_index, answer_question
 
 Role = Literal["user", "assistant"]
 _ROLE_USER = "user"
@@ -37,32 +37,21 @@ class State:
 
 
 def respond_to_chat(query: str, history: list[ChatMessage]):
+    if not check_if_exists():
+        print("computing the vector index and the BM25 retriever which will later be used")
+        precompute_index()
+
     assistant_message = ChatMessage(role=_ROLE_ASSISTANT)
     yield assistant_message
     state = me.state(State)
     if len(state.pdf_files) == 0:
         response = answer_question(query)
     else:
-
-        response = extract_final_answer(pdf_files, query)
+        response = answer_question(query)
 
     print("Agent response=", response)
     yield response
 
-    # messages = [{"role": message.role, "content": message.content} for message in history]
-    # llm_response = llm.create_chat_completion(
-    #     messages=messages,
-    #     max_tokens=1024,
-    #     stop=[],
-    #     stream=True
-    # )
-    # assistant_message = ChatMessage(role=_ROLE_ASSISTANT)
-    # yield assistant_message
-    # for item in llm_response:
-    #     delta = item['choices'][0]['delta']
-    #     if 'content' in delta:
-    #         text = delta["content"]
-    #         yield text
 
 def on_chat_input(e: me.InputEvent):
     state = me.state(State)
@@ -129,7 +118,7 @@ def _make_chat_bubble_style(role: Role) -> me.Style:
 
 
 def save_uploaded_file(uploaded_file: me.UploadedFile):
-    save_directory = "
+    save_directory = "data"
    os.makedirs(save_directory, exist_ok=True)
     file_path = os.path.join(save_directory, uploaded_file.name)
     with open(file_path, "wb") as f:
@@ -140,4 +129,6 @@ def save_uploaded_file(uploaded_file: me.UploadedFile):
 def handle_pdf_upload(event: me.UploadEvent):
     state = me.state(State)
     save_uploaded_file(event.file)
+    print("precomputing vector indices")
+    precompute_index()
     state.pdf_files.append(os.path.join("docs", event.file.name))
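Note on the streaming contract used above (context, not part of the commit): respond_to_chat is a generator that first yields an empty assistant ChatMessage and then yields the answer text, which the surrounding chat UI presumably appends to that message. A minimal, self-contained sketch of a consumer driving a generator with that shape — the ChatMessage stand-in and the drive() helper are illustrative, not the app's actual Mesop wiring:

from dataclasses import dataclass

@dataclass
class ChatMessage:
    # Stand-in for the app's ChatMessage dataclass.
    role: str = "assistant"
    content: str = ""

def respond(query: str):
    # Same shape as respond_to_chat: yield the message container first,
    # then yield one or more text chunks that belong to it.
    message = ChatMessage()
    yield message
    yield f"echo: {query}"

def drive(generator):
    # The first item is the message object; every later item is text
    # to append to that message's content.
    message = next(generator)
    for chunk in generator:
        message.content += chunk
    return message

print(drive(respond("hello")).content)  # -> echo: hello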
rag_app/rag_2.py
ADDED
@@ -0,0 +1,79 @@
+import re
+import os
+from llama_cpp import Llama, LlamaGrammar
+from llama_index.llms.llama_cpp import LlamaCPP
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
+from llama_index.retrievers.bm25 import BM25Retriever
+from llama_index.core.retrievers import QueryFusionRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core import StorageContext, load_index_from_storage
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.core.postprocessor import LLMRerank
+
+llm = LlamaCPP(
+    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
+    temperature=0.1,
+    max_new_tokens=256,
+    context_window=16384
+)
+embedding_model = HuggingFaceEmbedding(
+    model_name="models/all-MiniLM-L6-v2"
+)
+Settings.llm = llm
+Settings.embed_model = embedding_model
+
+
+def check_if_exists():
+    index = os.path.exists("models/precomputed_index")
+    bm25 = os.path.exists("models/bm25_retriever")
+    if index and bm25:
+        return True
+    else:
+        return False
+
+
+def precompute_index(data_folder='data'):
+    documents = SimpleDirectoryReader(data_folder).load_data()
+    index = VectorStoreIndex.from_documents(documents)
+    index.storage_context.persist(persist_dir='models/precomputed_index')
+    bm25_retriever = BM25Retriever.from_defaults(
+        nodes=documents,
+        similarity_top_k=5
+    )
+    bm25_retriever.persist("models/bm25_retriever")
+
+def is_harmful(query):
+    harmful_keywords = ["bomb", "kill", "weapon", "suicide", "terror", "attack"]
+    return any(keyword in query.lower() for keyword in harmful_keywords)
+
+
+def answer_question(query):
+    print("loading bm25 retriever")
+    bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
+    print("loading saved vector index")
+    storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
+    index = load_index_from_storage(storage_context)
+
+    retriever = QueryFusionRetriever(
+        [
+            index.as_retriever(similarity_top_k=5),
+            bm25_retriever,
+        ],
+        llm=llm,
+        num_queries=1,
+        similarity_top_k=5,
+    )
+    reranker = LLMRerank(
+        choice_batch_size=5,
+        top_n=5,
+    )
+    keyword_query_engine = RetrieverQueryEngine(
+        retriever=retriever,
+        node_postprocessors=[reranker]
+    )
+
+    if is_harmful(query):
+        return "This query has been flagged as unsafe."
+
+    response = keyword_query_engine.query(query)
+    return str(response)
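For reference (not part of the commit), the module above follows a build-once / query-many pattern: precompute_index() embeds everything under data/ and persists both the vector index and the BM25 retriever under models/, check_if_exists() gates rebuilds, and answer_question() reloads both retrievers, fuses dense and keyword results with QueryFusionRetriever, and reranks with the same local LLM. A short sketch of exercising it, assuming the GGUF model, the MiniLM embedding model, and a populated data/ folder are in place (the paths are hard-coded in rag_2.py):

# Sketch of the intended flow; run from the repo root so the relative
# "models/" and "data/" paths used inside rag_2.py resolve.
from rag_app.rag_2 import check_if_exists, precompute_index, answer_question

if not check_if_exists():
    # One-time build: persists models/precomputed_index and
    # models/bm25_retriever so later queries skip re-embedding.
    precompute_index()

# Loads the persisted index and BM25 retriever, fuses the two result
# lists, reranks, and answers with the local Llama 3.2 1B model.
print(answer_question("What is the uploaded document about?"))

The hybrid setup hedges against embedding misses: BM25 catches exact keyword matches that a dense retriever can rank low, and the fusion step merges both candidate lists before reranking.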
requirements.txt
CHANGED
@@ -12,4 +12,22 @@ pdfplumber
 pypdf2
 torch==2.6.0
 torchaudio==2.6.0
-torchvision==0.21.0
+torchvision==0.21.0
+llama-index==0.12.24
+llama-index-agent-openai==0.4.6
+llama-index-cli==0.4.1
+llama-index-core==0.12.24.post1
+llama-index-embeddings-huggingface==0.5.2
+llama-index-embeddings-openai==0.3.1
+llama-index-indices-managed-llama-cloud==0.6.9
+llama-index-llms-llama-cpp==0.4.0
+llama-index-llms-openai==0.3.25
+llama-index-multi-modal-llms-openai==0.4.3
+llama-index-postprocessor-cohere-rerank==0.3.0
+llama-index-postprocessor-colbert-rerank==0.3.0
+llama-index-program-openai==0.3.1
+llama-index-question-gen-openai==0.3.0
+llama-index-readers-file==0.4.6
+llama-index-readers-llama-parse==0.4.0
+llama-index-retrievers-bm25==0.5.2
+llama-parse==0.6.4.post1