Ritesh-hf committed
Commit 244fa56 · verified · 1 Parent(s): f5e3372

Update app.py

Files changed (1): app.py +376 -205
app.py CHANGED
@@ -1,245 +1,416 @@
  import nltk
- nltk.download('punkt_tab')

- import os
  from dotenv import load_dotenv
- import asyncio
- from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
- from fastapi.responses import HTMLResponse
- from fastapi.templating import Jinja2Templates
  from fastapi.middleware.cors import CORSMiddleware
- from langchain.chains import create_history_aware_retriever, create_retrieval_chain
- from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain_community.chat_message_histories import ChatMessageHistory
- from langchain_core.chat_history import BaseChatMessageHistory
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langchain_core.runnables.history import RunnableWithMessageHistory
  from pinecone import Pinecone
  from pinecone_text.sparse import BM25Encoder
- from langchain_huggingface import HuggingFaceEmbeddings
  from langchain_community.retrievers import PineconeHybridSearchRetriever
- from langchain.retrievers import ContextualCompressionRetriever
- from langchain_community.chat_models import ChatPerplexity
- from langchain.retrievers.document_compressors import CrossEncoderReranker
- from langchain_community.cross_encoders import HuggingFaceCrossEncoder
- from langchain_core.prompts import PromptTemplate
- from langchain.retrievers.document_compressors import FlashrankRerank
- import re

- # Load environment variables
  load_dotenv(".env")
- USER_AGENT = os.getenv("USER_AGENT")
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
- SECRET_KEY = os.getenv("SECRET_KEY")
- PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
- SESSION_ID_DEFAULT = "abc123"
-
- # Set environment variables
- os.environ['USER_AGENT'] = USER_AGENT
- os.environ["GROQ_API_KEY"] = GROQ_API_KEY
- os.environ["TOKENIZERS_PARALLELISM"] = 'true'
-
- # Initialize FastAPI app and CORS
- app = FastAPI()
- origins = ["*"] # Adjust as needed

  app.add_middleware(
      CORSMiddleware,
-     allow_origins=origins,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

- templates = Jinja2Templates(directory="templates")

- # Function to initialize Pinecone connection
- def initialize_pinecone(index_name: str):
-     try:
-         pc = Pinecone(api_key=PINECONE_API_KEY)
-         return pc.Index(index_name)
-     except Exception as e:
-         print(f"Error initializing Pinecone: {e}")
-         raise


- ##################################################
- ## Change down here
- ##################################################
- # #### This is for UAE Economic Department Website
- pinecone_index = initialize_pinecone("updated-saudi-arabia-ministry-of-justice")
- bm25 = BM25Encoder().load("./updated-saudi-arabia-bm25-encoder.json")
- ##################################################
- ##################################################
-
- # Initialize models and retriever
- embed_model = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code":True})
- retriever = PineconeHybridSearchRetriever(
-     embeddings=embed_model,
-     sparse_encoder=bm25,
-     index=pinecone_index,
-     top_k=10,
-     alpha=0.5,
- )

- # Initialize LLM
- llm = ChatPerplexity(temperature=0, pplx_api_key=GROQ_API_KEY, model="llama-3.1-sonar-large-128k-chat", max_tokens=512, max_retries=2)
-
- # Initialize Reranker
- # model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
- # compressor = CrossEncoderReranker(model=model, top_n=10)
-
- # compression_retriever = ContextualCompressionRetriever(
- #     base_compressor=compressor, base_retriever=retriever
- # )
- # from langchain.retrievers.document_compressors import LLMChainExtractor
-
- # compressor = LLMChainExtractor.from_llm(llm)
- # compression_retriever = ContextualCompressionRetriever(
- #     base_compressor=compressor, base_retriever=retriever
- # )
-
- # compressor = FlashrankRerank(top_n=10)
- # compression_retriever = ContextualCompressionRetriever(
- #     base_compressor=compressor, base_retriever=retriever
- # )
-
- # Contextualization prompt and retriever
- contextualize_q_system_prompt = """ Given a chat history and the latest user question \
- which might reference context in the chat history, formulate a standalone question \
- which can be understood without the chat history. Do NOT answer the question, \
- just reformulate it if needed and otherwise return it as is.
- """
- contextualize_q_prompt = ChatPromptTemplate.from_messages(
-     [
-         ("system", contextualize_q_system_prompt),
-         MessagesPlaceholder("chat_history"),
-         ("human", "{input}")
-     ]
- )
- history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)

- # QA system prompt and chain
- qa_system_prompt = """ You are a highly skilled information retrieval assistant. Use the following context to answer questions effectively.
- If you don't know the answer, simply state that you don't know.
-
- YOUR ANSWER SHOULD BE IN '{language}' LANGUAGE.
-
- When responding to queries, follow these guidelines:
-
- 1. Provide Clear Answers:
-    - You have to answer in that language based on the given language of the answer. If it is English, answer it in English; if it is Arabic, you should answer it in Arabic.
-    - Ensure the response directly addresses the query with accurate and relevant information.
-    - Do not give long answers. Provide detailed but concise responses.
-
- 2. Formatting for Readability:
-    - Provide the entire response in proper markdown format.
-    - Use structured Markdown elements such as headings, subheadings, lists, tables, and links.
-    - Use emphasis on headings, important texts, and phrases.
-
- 3. Proper Citations:
-    - Always use inline citations with embedded source URLs.
-    - The inline citations should be in the format [1], [2], etc.
-    - DO NOT INCLUDE THE 'References' SECTION IN THE RESPONSE.
-
- FOLLOW ALL THE GIVEN INSTRUCTIONS, FAILURE TO DO SO WILL RESULT IN THE TERMINATION OF THE CHAT.
-
- == CONTEXT ==
-
- {context}
  """
- qa_prompt = ChatPromptTemplate.from_messages(
-     [
-         ("system", qa_system_prompt),
-         MessagesPlaceholder("chat_history"),
-         ("human", "{input}")
-     ]
- )

- document_prompt = PromptTemplate(input_variables=["page_content"], template="{page_content}")
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=document_prompt)

- # Retrieval and Generative (RAG) Chain
- rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

- # Chat message history storage
- store = {}

- def get_session_history(session_id: str) -> BaseChatMessageHistory:
-     if session_id not in store:
-         store[session_id] = ChatMessageHistory()
-     return store[session_id]

- # Conversational RAG chain with message history
- conversational_rag_chain = RunnableWithMessageHistory(
-     rag_chain,
-     get_session_history,
-     input_messages_key="input",
-     history_messages_key="chat_history",
-     language_message_key="language",
-     output_messages_key="answer",
- )


- # WebSocket endpoint with streaming
- @app.websocket("/ws")
  async def websocket_endpoint(websocket: WebSocket):
      await websocket.accept()
-     print(f"Client connected: {websocket.client}")
-     session_id = None
      try:
-         while True:
-             data = await websocket.receive_json()
-             question = data.get('question')
-             language = data.get('language')
-             if "en" in language:
-                 language = "English"
-             else:
-                 language = "Arabic"
-             session_id = data.get('session_id', SESSION_ID_DEFAULT)
-             # Process the question
-             try:
-                 # Define an async generator for streaming
-                 async def stream_response():
-                     complete_response = ""
-                     context = {}
-                     async for chunk in conversational_rag_chain.astream(
-                         {"input": question, 'language': language},
-                         config={"configurable": {"session_id": session_id}}
-                     ):
-                         if "context" in chunk:
-                             context = chunk['context']
-                         # Send each chunk to the client
-                         if "answer" in chunk:
-                             complete_response += chunk['answer']
-                             await websocket.send_json({'response': chunk['answer']})
-                     if context:
-                         citations = re.findall(r'\[(\d+)\]', complete_response)
-                         citation_numbers = list(map(int, citations))
-                         sources = dict()
-                         backup = dict()
-                         i=1
-                         for index, doc in enumerate(context):
-                             if (index+1) in citation_numbers:
-                                 sources[f"[{index+1}]"] = doc.metadata["source"]
-                             else:
-                                 if doc.metadata["source"] not in backup.values():
-                                     backup[f"[{i}]"] = doc.metadata["source"]
-                                     i += 1
-                         if sources:
-                             await websocket.send_json({'sources': sources})
-                         else:
-                             await websocket.send_json({'sources': backup})
-                 await stream_response()
-             except Exception as e:
-                 print(f"Error during message handling: {e}")
-                 await websocket.send_json({'response': "Something went wrong, Please try again.." + str(e)})
      except WebSocketDisconnect:
-         print(f"Client disconnected: {websocket.client}")
-         if session_id:
-             store.pop(session_id, None)
-
- # Home route
- @app.get("/", response_class=HTMLResponse)
- async def read_index(request: Request):
-     return templates.TemplateResponse("chat.html", {"request": request})

+ import asyncio
+ import os
+ import re
+ import time
+ import logging
+
  import nltk
+ # Pre-download the required nltk resource if not already available.
+ try:
+     nltk.data.find('tokenizers/punkt_tab')
+ except LookupError:
+     nltk.download('punkt_tab')

  from dotenv import load_dotenv
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel, Field, ValidationError
+ from typing import List, Dict, Tuple
+
  from pinecone import Pinecone
  from pinecone_text.sparse import BM25Encoder
  from langchain_community.retrievers import PineconeHybridSearchRetriever
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from openai import AsyncOpenAI

+
+ # ------------------------------------------------------------------------------
+ # Load environment variables and validate required ones
+ # ------------------------------------------------------------------------------
  load_dotenv(".env")

+ required_env_vars = [
+     "PINECONE_API_KEY",
+     "PERPLEXITY_API_KEY",
+     "OPENAI_API_KEY"  # Ensure the OpenAI API key is provided
+ ]
+ missing_vars = [var for var in required_env_vars if not os.getenv(var)]
+ if missing_vars:
+     raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
+
+ # ------------------------------------------------------------------------------
+ # Configure logging (consider structured logging in production)
+ # ------------------------------------------------------------------------------
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     handlers=[logging.StreamHandler()]
+ )
+ logger = logging.getLogger(__name__)
+
+ # ------------------------------------------------------------------------------
+ # Initialize FastAPI app with CORS middleware (restrict origins in production)
+ # ------------------------------------------------------------------------------
+ app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

+ # ------------------------------------------------------------------------------
+ # Initialize external services
+ # ------------------------------------------------------------------------------
+ try:
+     openai_client = AsyncOpenAI(
+         api_key=os.getenv("OPENAI_API_KEY"),
+     )
+     pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+     embed_model = HuggingFaceEmbeddings(
+         model_name="Snowflake/snowflake-arctic-embed-l-v2.0",
+         model_kwargs={"trust_remote_code": True}
+     )
+ except Exception as e:
+     logger.error(f"Service initialization error: {e}")
+     raise

+ # ------------------------------------------------------------------------------
+ # System prompt for the chat model
+ # ------------------------------------------------------------------------------
+ system_prompt = """ You are an **advanced AI assistant developed by lawa.ai**, designed to provide **precise, fact-based, and well-structured** responses to user queries. Your responses should be based **only** on the provided context, ensuring **accuracy, clarity, and transparency**.
+
+ If the context **does not contain** the answer, **state this explicitly** rather than guessing or making assumptions.
+
+ ---
+ ### **📌 Response Guidelines**
+
+ #### **1️⃣ Precision & Clarity**
+ - Format responses in **Markdown** for enhanced readability.
+ - Match the **response language** to the query's "Language" field.
+ - Ensure responses are **concise yet comprehensive**, avoiding excessive elaboration.
+
+ #### **2️⃣ Citing Sources Transparently**
+ - Use **numerical citations** ([1], [2], etc.) to indicate the source document of the information.
+ - Citations must be **placed immediately after the relevant statement**.
+ - Ensure citations map correctly to the order of documents in the provided context.
+
+ #### **3️⃣ Formatting for Readability**
+ - Use **bold text**, *italic text*, bullet points, and headings for emphasis.
+ - Organize responses into **logical sections** to improve structure.
+ - Provide **tables or bullet points** where appropriate for numerical/statistical data.
+
+ #### **4️⃣ Strictly Adhere to Context**
+ - Use **only** information from the provided context.
+ - **Do not** include external knowledge or speculate on missing details.
+
+ #### **5️⃣ Handling Missing or Insufficient Context**
+ - If the context does **not contain** a clear answer, respond with:
+   🛑 *"The provided context does not contain relevant information to answer your question."*
+ - If general knowledge is allowed, provide a well-informed but **non-speculative** response.
+
+ #### **6️⃣ Avoiding AI Hallucinations**
+ - **Do not fabricate data, statistics, or references**.
+ - **Do not assume missing details**—state explicitly if something is unclear.
+
+ #### **7️⃣ Self-Identification When Asked**
+ - If requested, clearly state:
+   *"I am an AI assistant developed by lawa.ai, designed to provide accurate responses based on provided context."*
+
+ ---
+ ### **📌 Strict Rules for Response Generation**
+ ✅ **Never mention the word "context" in responses.**
+ ✅ **Use only the relevant content from the provided context.**
+ ✅ **If no relevant information exists, say so explicitly.**
+
+ ---
+ ### **📌 Input Format Example**
+ **User Query:**
+ *"What are the latest updates on the scholarship policies at MBZUAI?"*
+ **Language:** *English*
+ **Context:**
+ ```text
+ <provided context>
+ ```
+
+ ---
+ ### **📌 Expected Output Format**
+ ```markdown
+ ### **Latest Updates on MBZUAI Scholarship Policies**
+ MBZUAI recently updated its scholarship policies to include the following:
+
+ 1. **Scholarship Coverage:** Full tuition fees, accommodation, and a monthly stipend. [1]
+ 2. **Eligibility Criteria:** Applicants must maintain a GPA of 3.5 or higher. [2]
+
+ For further details, please refer to the official documents. If you have more specific questions, feel free to ask!
+ ```
+
+ ---
+ ### **📌 Example Question & Response**
+ #### **User Query:**
+ *"I overstayed my tourist visa in the UAE. What penalties or fines will I face, and how can I resolve this legally?"*
+ #### **Provided Context:**
+ ```text
+ <related regulations on visa overstay penalties>
+ ```
+ #### **Generated Response:**
+ ```markdown
+ ### **UAE Tourist Visa Overstay Penalties**
+ Overstaying a UAE tourist visa incurs specific penalties and requires prompt action to avoid legal issues.
+
+ #### **Fines & Fees**
+ - **Daily Fine:** AED 50 per day beyond the visa expiry. [1]
+ - **Exit Fee:** Additional AED 200 upon departure. [2]
+
+ #### **Steps to Resolve the Issue**
+ 1. **Calculate Total Fines:** Multiply overstayed days by AED 50 and add any exit fees.
+ 2. **Visit an Immigration Office:** Report to the General Directorate of Residency and Foreigners Affairs (GDRFA) or an Amer service center in Dubai.
+ 3. **Pay the Fines:** Payments can be made at immigration offices, airports, land borders, or seaports upon departure. [3]
+ 4. **Apply for a Visa Extension:** If you wish to stay longer, request a visa extension or status change before expiry. [4]
+
+ #### **Additional Considerations**
+ - **Grace Period:** Some visas offer a grace period before fines apply. [5]
+ - **Legal Assistance:** If needed, consult immigration experts for further guidance.
+
+ Acting promptly helps minimize fines and maintain a clean immigration record in the UAE.
+ ```
  """

+ # ------------------------------------------------------------------------------
+ # Pydantic models for request/response validation
+ # ------------------------------------------------------------------------------
+ class ChatRequest(BaseModel):
+     question: str = Field(..., max_length=1024)
+     language: str
+     previous_chats: List[dict]
+
+ class CitationSource(BaseModel):
+     url: str
+     cite_num: str
+
+ # ------------------------------------------------------------------------------
+ # Initialize Pinecone retriever with retries
+ # ------------------------------------------------------------------------------
+ MAX_RETRIES = 3
+
+ def initialize_pinecone():
+     for attempt in range(MAX_RETRIES):
+         try:
+             index = pc.Index("saudi-arabia-moj")
+             bm25 = BM25Encoder().load("./saudi-arabia-moj.json")
+             return PineconeHybridSearchRetriever(
+                 embeddings=embed_model,
+                 sparse_encoder=bm25,
+                 index=index,
+                 top_k=40,  # Hardcoded as required
+                 alpha=0.6,  # Hardcoded as required
+             )
+         except Exception as e:
+             logger.warning(f"Pinecone initialization attempt {attempt + 1} failed: {e}")
+             if attempt == MAX_RETRIES - 1:
+                 raise
+             time.sleep(2 ** attempt)
+
+ retriever = initialize_pinecone()
+
+ # ------------------------------------------------------------------------------
+ # Utility function to send messages safely over the websocket
+ # ------------------------------------------------------------------------------
+ async def safe_send(websocket: WebSocket, message: dict):
+     try:
+         await websocket.send_json(message)
+     except WebSocketDisconnect:
+         logger.info("Client disconnected during send")
+         raise
+     except Exception as e:
+         logger.error(f"Error sending message: {e}")
+         raise
+
+ # ------------------------------------------------------------------------------
+ # Helper functions for document processing and query formatting
+ # ------------------------------------------------------------------------------
+ def rerank_docs(query: str, docs: List[dict], pc_client: Pinecone) -> List[dict]:
+     try:
+         result = pc_client.inference.rerank(
+             model="cohere-rerank-3.5",
+             query=query,
+             documents=docs,
+             rank_fields=["chunk"],
+             top_n=20,
+             return_documents=True
+         )
+         ranked_docs = [{
+             "page_source": ele.document.page_source,
+             "chunk": ele.document.chunk,
+             "summary": ele.document.summary
+         } for ele in result.data]
+         return ranked_docs
+     except Exception as e:
+         logger.error(f"Error in rerank_docs: {e}")
+         raise
+
+ def format_docs(docs: List[dict]) -> str:
+     context = ""
+     for index, ele in enumerate(docs):
+         context += (
+             f"\n{'=' * 150}\n"
+             f"**DOCUMENT:** {index + 1}\n"
+             f"**SOURCE:** {ele['page_source']}\n\n"
+             f"**CONTENT:** {ele['chunk']}\n\n"
+         )
+     return context
+
+ def format_query(query: str, language: str, docs: List[dict]) -> str:
+     formatted_docs = format_docs(docs)
+     return f"**USER QUERY:** {query}\n**LANGUAGE:** {language}\n**CONTEXT:**\n{formatted_docs}"
+
+ def validate_citation_numbers(citation_numbers: List[int], max_docs: int) -> List[int]:
+     return [num for num in citation_numbers if 1 <= num <= max_docs]
+
+ def process_citations(complete_answer: str, ranked_docs: List[dict]) -> Tuple[str, List[dict]]:
+     """
+     Extracts citation numbers from the answer, maps them to consecutive citation numbers,
+     and returns the updated answer along with a list of citation sources.
+     """
+     citations = []
+     seen_nums = set()
+     citation_numbers = []
+     for num_str in re.findall(r'\[(\d+)\]', complete_answer):
+         num = int(num_str)
+         if num not in seen_nums:
+             seen_nums.add(num)
+             citation_numbers.append(num)
+     valid_citations = validate_citation_numbers(citation_numbers, len(ranked_docs))
+
+     seen_urls = {}
+     citation_map = {}
+     current_num = 1
+     for num in valid_citations:
+         try:
+             url = ranked_docs[num - 1]["page_source"]
+             if url not in seen_urls:
+                 citation_map[num] = current_num
+                 seen_urls[url] = current_num
+                 citations.append({"url": url, "cite_num": str(current_num)})
+                 current_num += 1
+             else:
+                 citation_map[num] = seen_urls[url]
+         except IndexError:
+             continue
+
+     logger.debug(f"Citation numbers extracted: {citation_numbers}")
+     logger.debug(f"Seen URLs mapping: {seen_urls}")
+
+     def replace_citation(match):
+         original = int(match.group(1))
+         new_num = citation_map.get(original, original)
+         url = next((c["url"] for c in citations if c["cite_num"] == str(new_num)), "")
+         return f"[{new_num}]({url})" if url else f"[{new_num}]"
+
+     updated_answer = re.sub(r'\[(\d+)\]', replace_citation, complete_answer)
+     return updated_answer, sorted(citations, key=lambda x: int(x["cite_num"]))
+
+ # ------------------------------------------------------------------------------
+ # WebSocket endpoint for chat functionality with improved error handling
+ # ------------------------------------------------------------------------------
+ @app.websocket("/chat")
  async def websocket_endpoint(websocket: WebSocket):
      await websocket.accept()
      try:
+         # Receive and validate the request
+         try:
+             data = await asyncio.wait_for(websocket.receive_json(), timeout=30)
+             chat_request = ChatRequest(**data)
+         except ValidationError as e:
+             logger.error(f"Validation error: {e}")
+             await safe_send(websocket, {"response": "Something went wrong with your request!", "sources": []})
+             return
+         except Exception as e:
+             logger.error(f"Error receiving data: {e}")
+             await safe_send(websocket, {"response": "Something went wrong with your request!", "sources": []})
+             return
+
+         question = chat_request.question
+         language = chat_request.language
+
+         # Retrieve documents using the retriever
+         try:
+             retrieved_docs = await asyncio.to_thread(retriever.invoke, question)
+         except Exception as e:
+             logger.error(f"Document retrieval error: {e}")
+             await safe_send(websocket, {"response": "Document retrieval failed", "sources": []})
+             return
+
+         docs = [{
+             "summary": ele.metadata.get("summary", ""),
+             "chunk": ele.page_content,
+             "page_source": ele.metadata.get("source", "")
+         } for ele in retrieved_docs]
+
+         if not docs:
+             await safe_send(websocket, {"response": "Cannot provide answer to this question", "sources": []})
+             return
+
+         # Rerank the documents (fallback to original docs if reranking fails)
+         try:
+             ranked_docs = await asyncio.to_thread(rerank_docs, question, docs, pc)
+         except Exception as e:
+             logger.error(f"Reranking error: {e}")
+             ranked_docs = docs
+
+         # Prepare the conversation messages
+         messages = [{"role": "system", "content": system_prompt}]
+         messages.extend(chat_request.previous_chats)
+         messages.append({"role": "user", "content": format_query(question, language, ranked_docs)})
+
+         complete_answer = ""
+         chunk_buffer = ""
+
+         # Generate and stream the chat response
+         try:
+             completion = await openai_client.chat.completions.create(
+                 model="gpt-4o",
+                 messages=messages,
+                 temperature=0.2,
+                 max_completion_tokens=1024,
+                 stream=True
+             )
+             async for chunk in completion:
+                 delta_content = chunk.choices[0].delta.content
+                 if delta_content:
+                     complete_answer += delta_content
+                     # Remove inline citation markers from the streamed chunk before sending
+                     cleaned_content = re.sub(r'\[\d+\]', '', delta_content)
+                     chunk_buffer += cleaned_content
+                     if len(chunk_buffer) >= 1:
+                         await safe_send(websocket, {"response": chunk_buffer})
+                         chunk_buffer = ""
+             if chunk_buffer:
+                 await safe_send(websocket, {"response": chunk_buffer})
+         except Exception as e:
+             logger.error(f"Streaming error: {e}")
+             await safe_send(websocket, {"response": "Response generation failed", "sources": []})
+             return
+
+         # Process and map citations in the final answer
+         complete_answer, citations = process_citations(complete_answer, ranked_docs)
+
+         await safe_send(websocket, {
+             "response": complete_answer,
+             "sources": citations
+         })
+
      except WebSocketDisconnect:
+         logger.info("Client disconnected")
+     except Exception as e:
+         logger.error(f"Unexpected error: {e}")
+         await safe_send(websocket, {"response": "Something went wrong! Please try again.", "sources": []})
+
+ # ------------------------------------------------------------------------------
+ # Simple health check endpoint
+ # ------------------------------------------------------------------------------
+ @app.get("/")
+ async def root():
+     return JSONResponse(content={"message": "working"})
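
For reference, here is a minimal client sketch for exercising the new `/chat` WebSocket endpoint. This is illustrative only and not part of the commit: it assumes the app is served locally on port 8000 and uses the third-party `websockets` package; the request fields mirror the `ChatRequest` model above, and the response frames mirror the `safe_send` payloads (streaming `{"response": ...}` chunks, then a final frame that adds `"sources"`).

```python
# Hypothetical client for the /chat endpoint (not part of this commit).
# Assumptions: the app runs at ws://localhost:8000 and `websockets` is installed.
import asyncio
import json

import websockets


async def ask(question: str, language: str = "English") -> None:
    async with websockets.connect("ws://localhost:8000/chat") as ws:
        # Field names mirror the ChatRequest model in app.py.
        await ws.send(json.dumps({
            "question": question,
            "language": language,
            "previous_chats": [],  # prior {"role": ..., "content": ...} turns, if any
        }))
        while True:
            frame = json.loads(await ws.recv())
            if "sources" in frame:
                # Final (or error) frame: the full answer with citation links, plus sources.
                print("\n\nSources:", frame["sources"])
                break
            # Streaming frames carry incremental text under "response".
            print(frame.get("response", ""), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(ask("What are the penalties for overstaying a tourist visa?"))
```

And a quick illustration of the citation remapping that `process_citations` performs; the URLs are hypothetical, and the expected values follow from the function body above (duplicate source URLs collapse onto one citation number, and bare markers become Markdown links):

```python
# Hypothetical usage of process_citations (assumes it is importable from app.py).
from app import process_citations

ranked_docs = [
    {"page_source": "https://example.com/a", "chunk": "...", "summary": "..."},
    {"page_source": "https://example.com/a", "chunk": "...", "summary": "..."},
    {"page_source": "https://example.com/b", "chunk": "...", "summary": "..."},
]
answer = "Fine is AED 50 per day. [1] Pay at the airport. [2] A grace period applies. [3]"

updated, sources = process_citations(answer, ranked_docs)
# updated -> "... [1](https://example.com/a) ... [1](https://example.com/a) ... [2](https://example.com/b)"
# sources -> [{"url": "https://example.com/a", "cite_num": "1"},
#             {"url": "https://example.com/b", "cite_num": "2"}]
```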