snsynth committed
Commit b0fc7d6 · 1 Parent(s): ef898b6

revamp rag

Files changed (3)
  1. rag_app/chat_utils.py +10 -19
  2. rag_app/rag_2.py +79 -0
  3. requirements.txt +19 -1
rag_app/chat_utils.py CHANGED
@@ -3,8 +3,8 @@ import mesop as me
 from dataclasses import dataclass, field
 from typing import Callable, Generator, Literal
 import time
-from rag_app.rag import extract_final_answer, answer_question
-
+# from rag_app.rag import extract_final_answer, answer_question
+from rag_app.rag_2 import check_if_exists, precompute_index, answer_question
 
 Role = Literal["user", "assistant"]
 _ROLE_USER = "user"
@@ -37,32 +37,21 @@ class State:
 
 
 def respond_to_chat(query: str, history: list[ChatMessage]):
+    if not check_if_exists():
+        print("computing the vector index and the BM25 retriever for later use")
+        precompute_index()
+
     assistant_message = ChatMessage(role=_ROLE_ASSISTANT)
     yield assistant_message
     state = me.state(State)
     if len(state.pdf_files) == 0:
         response = answer_question(query)
     else:
-        pdf_files = state.pdf_files
-        response = extract_final_answer(pdf_files, query)
+        response = answer_question(query)
 
     print("Agent response=", response)
     yield response
 
-    # messages = [{"role": message.role, "content": message.content} for message in history]
-    # llm_response = llm.create_chat_completion(
-    #     messages=messages,
-    #     max_tokens=1024,
-    #     stop=[],
-    #     stream=True
-    # )
-    # assistant_message = ChatMessage(role=_ROLE_ASSISTANT)
-    # yield assistant_message
-    # for item in llm_response:
-    #     delta = item['choices'][0]['delta']
-    #     if 'content' in delta:
-    #         text = delta["content"]
-    #         yield text
 
 def on_chat_input(e: me.InputEvent):
     state = me.state(State)
@@ -129,7 +118,7 @@ def _make_chat_bubble_style(role: Role) -> me.Style:
 
 
 def save_uploaded_file(uploaded_file: me.UploadedFile):
-    save_directory = "docs"
+    save_directory = "data"
     os.makedirs(save_directory, exist_ok=True)
     file_path = os.path.join(save_directory, uploaded_file.name)
     with open(file_path, "wb") as f:
@@ -140,4 +129,6 @@ def save_uploaded_file(uploaded_file: me.UploadedFile):
 def handle_pdf_upload(event: me.UploadEvent):
     state = me.state(State)
     save_uploaded_file(event.file)
+    print("precomputing vector indices")
+    precompute_index()
     state.pdf_files.append(os.path.join("docs", event.file.name))
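After this change both branches of the state.pdf_files check call answer_question(query), and uploads are now saved to data/ (where precompute_index reads from) while state.pdf_files still records paths under docs/. Since the branch is redundant, the handler could collapse to a sketch like the following (my simplification against the post-commit module, not code from the commit):

def respond_to_chat(query: str, history: list[ChatMessage]):
    # Build and persist the retrievers once, on first use.
    if not check_if_exists():
        precompute_index()

    # Yield an empty assistant message first so the UI can open a chat bubble.
    assistant_message = ChatMessage(role=_ROLE_ASSISTANT)
    yield assistant_message

    # With or without uploaded PDFs, everything now flows through the
    # rag_2 pipeline, so the state.pdf_files branch can be dropped.
    yield answer_question(query)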
rag_app/rag_2.py ADDED
@@ -0,0 +1,79 @@
+import re
+import os
+from llama_cpp import Llama, LlamaGrammar
+from llama_index.llms.llama_cpp import LlamaCPP
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
+from llama_index.retrievers.bm25 import BM25Retriever
+from llama_index.core.retrievers import QueryFusionRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core import StorageContext, load_index_from_storage
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.core.postprocessor import LLMRerank
+
+llm = LlamaCPP(
+    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
+    temperature=0.1,
+    max_new_tokens=256,
+    context_window=16384
+)
+embedding_model = HuggingFaceEmbedding(
+    model_name="models/all-MiniLM-L6-v2"
+)
+Settings.llm = llm
+Settings.embed_model = embedding_model
+
+
+def check_if_exists():
+    index = os.path.exists("models/precomputed_index")
+    bm25 = os.path.exists("models/bm25_retriever")
+    if index and bm25:
+        return True
+    else:
+        return False
+
+
+def precompute_index(data_folder='data'):
+    documents = SimpleDirectoryReader(data_folder).load_data()
+    index = VectorStoreIndex.from_documents(documents)
+    index.storage_context.persist(persist_dir='models/precomputed_index')
+    bm25_retriever = BM25Retriever.from_defaults(
+        nodes=documents,
+        similarity_top_k=5
+    )
+    bm25_retriever.persist("models/bm25_retriever")
+
+def is_harmful(query):
+    harmful_keywords = ["bomb", "kill", "weapon", "suicide", "terror", "attack"]
+    return any(keyword in query.lower() for keyword in harmful_keywords)
+
+
+def answer_question(query):
+    print("loading bm25 retriever")
+    bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
+    print("loading saved vector index")
+    storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
+    index = load_index_from_storage(storage_context)
+
+    retriever = QueryFusionRetriever(
+        [
+            index.as_retriever(similarity_top_k=5),
+            bm25_retriever,
+        ],
+        llm=llm,
+        num_queries=1,
+        similarity_top_k=5,
+    )
+    reranker = LLMRerank(
+        choice_batch_size=5,
+        top_n=5,
+    )
+    keyword_query_engine = RetrieverQueryEngine(
+        retriever=retriever,
+        node_postprocessors=[reranker]
+    )
+
+    if is_harmful(query):
+        return "This query has been flagged as unsafe."
+
+    response = keyword_query_engine.query(query)
+    return str(response)
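A quick smoke test of the new module, assuming PDFs already sit in data/ and the GGUF and MiniLM weights exist at the hard-coded paths under models/ (the question string is illustrative):

from rag_app.rag_2 import check_if_exists, precompute_index, answer_question

# First run: build and persist the vector index and the BM25 retriever.
if not check_if_exists():
    precompute_index(data_folder="data")

# Later runs reload both retrievers from disk, fuse them with
# QueryFusionRetriever, rerank with LLMRerank, and synthesize an answer.
print(answer_question("What are the main findings of the uploaded paper?"))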
requirements.txt CHANGED
@@ -12,4 +12,22 @@ pdfplumber
 pypdf2
 torch==2.6.0
 torchaudio==2.6.0
-torchvision==0.21.0
+torchvision==0.21.0
+llama-index==0.12.24
+llama-index-agent-openai==0.4.6
+llama-index-cli==0.4.1
+llama-index-core==0.12.24.post1
+llama-index-embeddings-huggingface==0.5.2
+llama-index-embeddings-openai==0.3.1
+llama-index-indices-managed-llama-cloud==0.6.9
+llama-index-llms-llama-cpp==0.4.0
+llama-index-llms-openai==0.3.25
+llama-index-multi-modal-llms-openai==0.4.3
+llama-index-postprocessor-cohere-rerank==0.3.0
+llama-index-postprocessor-colbert-rerank==0.3.0
+llama-index-program-openai==0.3.1
+llama-index-question-gen-openai==0.3.0
+llama-index-readers-file==0.4.6
+llama-index-readers-llama-parse==0.4.0
+llama-index-retrievers-bm25==0.5.2
+llama-parse==0.6.4.post1
+ llama-parse==0.6.4.post1