codelion committed on
Commit
d4f6a15
·
verified ·
1 Parent(s): 7783c17

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +114 -80
main.py CHANGED
@@ -5,54 +5,53 @@ import streamlit as st
5
  import anthropic
6
  from requests import JSONDecodeError
7
 
8
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 
9
  from langchain_community.vectorstores import SupabaseVectorStore
10
- from langchain_community.llms import HuggingFaceEndpoint
11
- from langchain_community.chat_models import ChatOpenAI
12
 
13
- from langchain.chains import ConversationalRetrievalChain
14
  from langchain.memory import ConversationBufferMemory
 
 
 
 
15
 
16
  from supabase import Client, create_client
17
  from streamlit.logger import get_logger
18
  from stats import get_usage, add_usage
19
 
20
  # ─────── supabase + secrets ────────────────────────────────────────────────────
21
- supabase_url = st.secrets.SUPABASE_URL
22
- supabase_key = st.secrets.SUPABASE_KEY
23
- openai_api_key = st.secrets.openai_api_key
24
  anthropic_api_key = st.secrets.anthropic_api_key
25
- hf_api_key = st.secrets.hf_api_key
26
- username = st.secrets.username
27
 
28
  supabase: Client = create_client(supabase_url, supabase_key)
29
  logger = get_logger(__name__)
30
 
31
- # ─────── embeddings ─────────────────────────────────────────────────────────────
32
- # Switch to local BGE embeddings (no JSONDecode errors, no HTTP-batch issues)
33
- embeddings = HuggingFaceBgeEmbeddings(
34
  model_name="BAAI/bge-large-en-v1.5",
35
  model_kwargs={"device": "cpu"},
36
  encode_kwargs={"normalize_embeddings": True}
37
  )
38
- # ─────── vector store + memory ─────────────────────────────────────────────────
 
39
  vector_store = SupabaseVectorStore(
40
  client=supabase,
41
  embedding=embeddings,
42
  query_name="match_documents",
43
  table_name="documents",
44
  )
45
- memory = ConversationBufferMemory(
46
- memory_key="chat_history",
47
- input_key="question",
48
- output_key="answer",
49
- return_messages=True,
50
- )
51
 
52
  # ─────── LLM setup ──────────────────────────────────────────────────────────────
53
- model = "HuggingFaceTB/SmolLM3-3B"
54
- temperature = 0.1
55
- max_tokens = 500
56
 
57
  import re
58
 
@@ -66,10 +65,10 @@ def clean_response(answer: str) -> str:
66
  answer = re.sub(r'<thinking>.*?</thinking>', '', answer, flags=re.DOTALL)
67
 
68
  # Remove other common AI response artifacts
69
- answer = re.sub(r'\[.*?\]', '', answer, flags=re.DOTALL) # Remove bracketed content
70
- answer = re.sub(r'\{.*?\}', '', answer, flags=re.DOTALL) # Remove curly bracketed content
71
- answer = re.sub(r'```.*?```', '', answer, flags=re.DOTALL) # Remove code blocks
72
- answer = re.sub(r'---.*?---', '', answer, flags=re.DOTALL) # Remove dashed sections
73
 
74
  # Remove excessive whitespace and newlines
75
  answer = re.sub(r'\s+', ' ', answer).strip()
@@ -79,65 +78,93 @@ def clean_response(answer: str) -> str:
79
  answer = re.sub(r'\s*(Sincerely,.*|Best regards,.*|Regards,.*)$', '', answer, flags=re.IGNORECASE)
80
 
81
  return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
def response_generator(query: str) -> str:
    """Answer `query` through the conversational retrieval chain.

    The prompt is recorded with `add_usage`, then a `ChatOpenAI` client
    pointed at the Hugging Face router and a `ConversationalRetrievalChain`
    over the module-level `vector_store` and `memory` are built and run.

    Returns:
        The cleaned answer text; a fixed apology when the chain reports no
        source documents; or a fixed error message when running the chain
        raises `JSONDecodeError`.
    """
    # Usage is recorded before any model work is attempted.
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)

    # OpenAI-compatible chat client talking to the Hugging Face router.
    chat_llm = ChatOpenAI(
        base_url=f"https://router.huggingface.co/hf-inference/models/{model}/v1",
        api_key=hf_api_key,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=30,
        max_retries=3,
    )

    # Retrieval is restricted to documents tagged with the configured user.
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
    )

    chain = ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=retriever,
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )

    try:
        result = chain({"question": query})
    except JSONDecodeError as exc:
        logger.error("Embedding JSONDecodeError: %s", exc)
        return "Sorry, I had trouble understanding the embedded data. Please try again."

    # Without any retrieved sources the model's answer is not shown.
    if not result.get("source_documents", []):
        return (
            "I’m sorry, I don’t have enough information to answer that. "
            "If you have a public data source to add, please email copilot@securade.ai."
        )

    return clean_response(result.get("answer", ""))
141
 
142
  # ─────── Streamlit UI ──────────────────────────────────────────────────────────
143
  st.set_page_config(
@@ -161,23 +188,30 @@ st.markdown(
161
  "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
162
  )
163
 
 
164
  if "chat_history" not in st.session_state:
165
  st.session_state.chat_history = []
166
 
167
- # show history
168
  for msg in st.session_state.chat_history:
169
  with st.chat_message(msg["role"]):
170
  st.markdown(msg["content"])
171
 
172
- # new user input
173
  if prompt := st.chat_input("Ask a question"):
 
174
  st.session_state.chat_history.append({"role": "user", "content": prompt})
 
 
175
  with st.chat_message("user"):
176
  st.markdown(prompt)
177
 
 
178
  with st.spinner("Safety briefing in progress..."):
179
- answer = response_generator(prompt)
180
 
181
  with st.chat_message("assistant"):
182
  st.markdown(answer)
183
- st.session_state.chat_history.append({"role": "assistant", "content": answer})
 
 
 
5
  import anthropic
6
  from requests import JSONDecodeError
7
 
8
+ # Updated imports for latest LangChain
9
+ from langchain_huggingface import HuggingFaceEmbeddings
10
  from langchain_community.vectorstores import SupabaseVectorStore
11
+ from langchain_openai import ChatOpenAI
12
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
13
 
14
+ # Updated memory and chain imports
15
  from langchain.memory import ConversationBufferMemory
16
+ from langchain.chains import create_retrieval_chain
17
+ from langchain.chains.combine_documents import create_stuff_documents_chain
18
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
19
+ from langchain_core.messages import HumanMessage, AIMessage
20
 
21
  from supabase import Client, create_client
22
  from streamlit.logger import get_logger
23
  from stats import get_usage, add_usage
24
 
25
  # ─────── supabase + secrets ────────────────────────────────────────────────────
26
# Credentials and deployment settings, all read from Streamlit secrets.
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username

# Shared Supabase client and module logger.
supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# Embedding model run on CPU, with normalized output vectors.
embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

# Vector store backed by the Supabase "documents" table and the
# "match_documents" query.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)

# Generation settings shared by the chain builder and the usage log.
model = "HuggingFaceTB/SmolLM3-3B"
temperature = 0.1
max_tokens = 500
55
 
56
  import re
57
 
 
65
  answer = re.sub(r'<thinking>.*?</thinking>', '', answer, flags=re.DOTALL)
66
 
67
  # Remove other common AI response artifacts
68
+ answer = re.sub(r'\[.*?\]', '', answer, flags=re.DOTALL)
69
+ answer = re.sub(r'\{.*?\}', '', answer, flags=re.DOTALL)
70
+ answer = re.sub(r'```.*?```', '', answer, flags=re.DOTALL)
71
+ answer = re.sub(r'---.*?---', '', answer, flags=re.DOTALL)
72
 
73
  # Remove excessive whitespace and newlines
74
  answer = re.sub(r'\s+', ' ', answer).strip()
 
78
  answer = re.sub(r'\s*(Sincerely,.*|Best regards,.*|Regards,.*)$', '', answer, flags=re.IGNORECASE)
79
 
80
  return answer
81
+
82
def create_conversational_rag_chain():
    """Build the retrieval-augmented chat chain used to answer questions.

    The chain combines:
      * a `ChatOpenAI` client pointed at the Hugging Face router for the
        module-level `model`, using the module-level `temperature` and
        `max_tokens`;
      * a retriever over the module-level `vector_store`, filtered to
        documents whose `user` metadata equals `username`;
      * a chat prompt made of a system instruction carrying the retrieved
        context, the prior conversation turns, and the current question.

    Returns:
        The runnable produced by `create_retrieval_chain`. Invoke it with
        `{"input": <question>, "chat_history": <list of chat messages>}`.
    """
    llm = ChatOpenAI(
        base_url=f"https://router.huggingface.co/hf-inference/models/{model}/v1",
        api_key=hf_api_key,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=30,
        max_retries=3,
    )

    # NOTE(review): search_type is left at its default here; confirm that the
    # Supabase store actually applies `score_threshold` in that mode.
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
    )

    # The system message carries only the instructions and the retrieved
    # context. The conversation history and the question are supplied by the
    # MessagesPlaceholder and the human turn below; repeating them here would
    # show the model the question twice and render the history as a
    # stringified list of message objects.
    system_prompt = (
        "You are a helpful safety assistant. Use the following pieces of "
        "retrieved context to answer the question.\n"
        "If you don't know the answer based on the context, just say that you "
        "don't have enough information to answer that question.\n"
        "\n"
        "Context: {context}"
    )

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])

    # Stuff the retrieved documents into `{context}`, then put retrieval in
    # front of that answering step.
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever, question_answer_chain)
126
+
127
def response_generator(query: str, chat_history: "list | None" = None) -> str:
    """Answer `query` with the RAG chain, falling back to fixed messages on error.

    Args:
        query: The user's question.
        chat_history: Earlier turns as dicts with "role" ("user" or
            "assistant") and "content" keys. Entries with any other role are
            ignored. May be omitted or None, which is treated as an empty
            history so one-argument calls keep working.

    Returns:
        The cleaned answer text; a fixed apology when the chain returns no
        context documents; or a fixed error message when invoking the chain
        or post-processing its result raises.
    """
    # Usage is recorded before any model work is attempted.
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)

    rag_chain = create_conversational_rag_chain()

    # Convert the stored role/content dicts into LangChain message objects.
    message_types = {"user": HumanMessage, "assistant": AIMessage}
    formatted_history = [
        message_types[msg["role"]](content=msg["content"])
        for msg in (chat_history or [])
        if msg["role"] in message_types
    ]

    try:
        result = rag_chain.invoke({
            "input": query,
            "chat_history": formatted_history,
        })

        # Without any retrieved context the model's answer is not shown.
        if not result.get("context", []):
            return (
                "I'm sorry, I don't have enough information to answer that. "
                "If you have a public data source to add, please email copilot@securade.ai."
            )

        return clean_response(result.get("answer", ""))

    except JSONDecodeError as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("JSONDecodeError: %s", e)
        return "Sorry, I had trouble processing your request. Please try again."
    except Exception as e:
        # Last-resort boundary: the user gets a fixed message, the log gets
        # the full traceback.
        logger.exception("Unexpected error: %s", e)
        return "Sorry, I encountered an error while processing your request. Please try again."
 
 
 
 
 
 
 
 
 
 
168
 
169
  # ─────── Streamlit UI ──────────────────────────────────────────────────────────
170
  st.set_page_config(
 
188
  "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
189
  )
190
 
191
# Make sure the per-session transcript exists before anything reads it.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay the stored transcript, one chat bubble per turn.
for past_message in st.session_state.chat_history:
    with st.chat_message(past_message["role"]):
        st.markdown(past_message["content"])

# React to a newly submitted question, if there is one.
prompt = st.chat_input("Ask a question")
if prompt:
    # Snapshot the earlier turns first so the chain receives the history
    # without the question that is being asked right now.
    previous_turns = list(st.session_state.chat_history)
    st.session_state.chat_history.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt, previous_turns)

    with st.chat_message("assistant"):
        st.markdown(answer)

    # Store the reply so it is replayed on the next rerun.
    st.session_state.chat_history.append({"role": "assistant", "content": answer})