Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,245 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import nltk
|
2 |
-
nltk.
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
import os
|
5 |
from dotenv import load_dotenv
|
6 |
-
import
|
7 |
-
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
8 |
-
from fastapi.responses import HTMLResponse
|
9 |
-
from fastapi.templating import Jinja2Templates
|
10 |
from fastapi.middleware.cors import CORSMiddleware
|
11 |
-
from
|
12 |
-
from
|
13 |
-
from
|
14 |
-
|
15 |
-
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
16 |
-
from langchain_core.runnables.history import RunnableWithMessageHistory
|
17 |
from pinecone import Pinecone
|
18 |
from pinecone_text.sparse import BM25Encoder
|
19 |
-
from langchain_huggingface import HuggingFaceEmbeddings
|
20 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
21 |
-
from
|
22 |
-
from
|
23 |
-
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
24 |
-
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
25 |
-
from langchain_core.prompts import PromptTemplate
|
26 |
-
from langchain.retrievers.document_compressors import FlashrankRerank
|
27 |
-
import re
|
28 |
|
29 |
-
|
|
|
|
|
|
|
30 |
load_dotenv(".env")
|
31 |
-
USER_AGENT = os.getenv("USER_AGENT")
|
32 |
-
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
33 |
-
SECRET_KEY = os.getenv("SECRET_KEY")
|
34 |
-
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
|
35 |
-
SESSION_ID_DEFAULT = "abc123"
|
36 |
-
|
37 |
-
# Set environment variables
|
38 |
-
os.environ['USER_AGENT'] = USER_AGENT
|
39 |
-
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
|
40 |
-
os.environ["TOKENIZERS_PARALLELISM"] = 'true'
|
41 |
-
|
42 |
-
# Initialize FastAPI app and CORS
|
43 |
-
app = FastAPI()
|
44 |
-
origins = ["*"] # Adjust as needed
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
app.add_middleware(
|
47 |
CORSMiddleware,
|
48 |
-
allow_origins=
|
49 |
allow_credentials=True,
|
50 |
allow_methods=["*"],
|
51 |
allow_headers=["*"],
|
52 |
)
|
53 |
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
#
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
return pc.Index(index_name)
|
61 |
-
except Exception as e:
|
62 |
-
print(f"Error initializing Pinecone: {e}")
|
63 |
-
raise
|
64 |
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
##################################################
|
69 |
-
# #### This is for UAE Economic Department Website
|
70 |
-
pinecone_index = initialize_pinecone("updated-saudi-arabia-ministry-of-justice")
|
71 |
-
bm25 = BM25Encoder().load("./updated-saudi-arabia-bm25-encoder.json")
|
72 |
-
##################################################
|
73 |
-
##################################################
|
74 |
-
|
75 |
-
# Initialize models and retriever
|
76 |
-
embed_model = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code":True})
|
77 |
-
retriever = PineconeHybridSearchRetriever(
|
78 |
-
embeddings=embed_model,
|
79 |
-
sparse_encoder=bm25,
|
80 |
-
index=pinecone_index,
|
81 |
-
top_k=10,
|
82 |
-
alpha=0.5,
|
83 |
-
)
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
# base_compressor=compressor, base_retriever=retriever
|
100 |
-
# )
|
101 |
-
|
102 |
-
# compressor = FlashrankRerank(top_n=10)
|
103 |
-
# compression_retriever = ContextualCompressionRetriever(
|
104 |
-
# base_compressor=compressor, base_retriever=retriever
|
105 |
-
# )
|
106 |
-
|
107 |
-
# Contextualization prompt and retriever
|
108 |
-
contextualize_q_system_prompt = """ Given a chat history and the latest user question \
|
109 |
-
which might reference context in the chat history, formulate a standalone question \
|
110 |
-
which can be understood without the chat history. Do NOT answer the question, \
|
111 |
-
just reformulate it if needed and otherwise return it as is.
|
112 |
-
"""
|
113 |
-
contextualize_q_prompt = ChatPromptTemplate.from_messages(
|
114 |
-
[
|
115 |
-
("system", contextualize_q_system_prompt),
|
116 |
-
MessagesPlaceholder("chat_history"),
|
117 |
-
("human", "{input}")
|
118 |
-
]
|
119 |
-
)
|
120 |
-
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
-
|
|
|
|
|
|
|
127 |
|
128 |
-
|
|
|
|
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
- Do not give long answers. Provide detailed but concise responses.
|
134 |
-
|
135 |
-
2. Formatting for Readability:
|
136 |
-
- Provide the entire response in proper markdown format.
|
137 |
-
- Use structured Markdown elements such as headings, subheadings, lists, tables, and links.
|
138 |
-
- Use emphasis on headings, important texts, and phrases.
|
139 |
-
|
140 |
-
3. Proper Citations:
|
141 |
-
- Always use inline citations with embedded source URLs.
|
142 |
-
- The inline citations should be in the format [1], [2], etc.
|
143 |
-
- DO NOT INCLUDE THE 'References' SECTION IN THE RESPONSE.
|
144 |
|
145 |
-
|
|
|
|
|
|
|
|
|
146 |
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
"""
|
151 |
-
qa_prompt = ChatPromptTemplate.from_messages(
|
152 |
-
[
|
153 |
-
("system", qa_system_prompt),
|
154 |
-
MessagesPlaceholder("chat_history"),
|
155 |
-
("human", "{input}")
|
156 |
-
]
|
157 |
-
)
|
158 |
|
159 |
-
|
160 |
-
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
-
|
163 |
-
|
|
|
164 |
|
165 |
-
#
|
166 |
-
|
|
|
|
|
167 |
|
168 |
-
def
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
|
|
|
|
183 |
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
async def websocket_endpoint(websocket: WebSocket):
|
187 |
await websocket.accept()
|
188 |
-
print(f"Client connected: {websocket.client}")
|
189 |
-
session_id = None
|
190 |
try:
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
except WebSocketDisconnect:
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
import logging
|
6 |
+
|
7 |
import nltk
|
8 |
+
# Pre-download the required nltk resource if not already available.
|
9 |
+
try:
|
10 |
+
nltk.data.find('tokenizers/punkt_tab')
|
11 |
+
except LookupError:
|
12 |
+
nltk.download('punkt_tab')
|
13 |
|
|
|
14 |
from dotenv import load_dotenv
|
15 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
|
|
|
|
|
16 |
from fastapi.middleware.cors import CORSMiddleware
|
17 |
+
from fastapi.responses import JSONResponse
|
18 |
+
from pydantic import BaseModel, Field, ValidationError
|
19 |
+
from typing import List, Dict, Tuple
|
20 |
+
|
|
|
|
|
21 |
from pinecone import Pinecone
|
22 |
from pinecone_text.sparse import BM25Encoder
|
|
|
23 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
24 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
25 |
+
from openai import AsyncOpenAI
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
|
28 |
+
# ------------------------------------------------------------------------------
|
29 |
+
# Load environment variables and validate required ones
|
30 |
+
# ------------------------------------------------------------------------------
|
31 |
load_dotenv(".env")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
required_env_vars = [
|
34 |
+
"PINECONE_API_KEY",
|
35 |
+
"PERPLEXITY_API_KEY",
|
36 |
+
"OPENAI_API_KEY" # Ensure the OpenAI API key is provided
|
37 |
+
]
|
38 |
+
missing_vars = [var for var in required_env_vars if not os.getenv(var)]
|
39 |
+
if missing_vars:
|
40 |
+
raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
|
41 |
+
|
42 |
+
# ------------------------------------------------------------------------------
|
43 |
+
# Configure logging (consider structured logging in production)
|
44 |
+
# ------------------------------------------------------------------------------
|
45 |
+
logging.basicConfig(
|
46 |
+
level=logging.INFO,
|
47 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
48 |
+
handlers=[logging.StreamHandler()]
|
49 |
+
)
|
50 |
+
logger = logging.getLogger(__name__)
|
51 |
+
|
52 |
+
# ------------------------------------------------------------------------------
|
53 |
+
# Initialize FastAPI app with CORS middleware (restrict origins in production)
|
54 |
+
# ------------------------------------------------------------------------------
|
55 |
+
app = FastAPI()
|
56 |
app.add_middleware(
|
57 |
CORSMiddleware,
|
58 |
+
allow_origins=["*"],
|
59 |
allow_credentials=True,
|
60 |
allow_methods=["*"],
|
61 |
allow_headers=["*"],
|
62 |
)
|
63 |
|
64 |
+
# ------------------------------------------------------------------------------
|
65 |
+
# Initialize external services
|
66 |
+
# ------------------------------------------------------------------------------
|
67 |
+
try:
|
68 |
+
openai_client = AsyncOpenAI(
|
69 |
+
api_key=os.getenv("OPENAI_API_KEY"),
|
70 |
+
)
|
71 |
+
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
|
72 |
+
embed_model = HuggingFaceEmbeddings(
|
73 |
+
model_name="Snowflake/snowflake-arctic-embed-l-v2.0",
|
74 |
+
model_kwargs={"trust_remote_code": True}
|
75 |
+
)
|
76 |
+
except Exception as e:
|
77 |
+
logger.error(f"Service initialization error: {e}")
|
78 |
+
raise
|
79 |
|
80 |
+
# ------------------------------------------------------------------------------
|
81 |
+
# System prompt for the chat model
|
82 |
+
# ------------------------------------------------------------------------------
|
83 |
+
system_prompt = """ You are an **advanced AI assistant developed by lawa.ai**, designed to provide **precise, fact-based, and well-structured** responses to user queries. Your responses should be based **only** on the provided context, ensuring **accuracy, clarity, and transparency**.
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
If the context **does not contain** the answer, **state this explicitly** rather than guessing or making assumptions.
|
86 |
|
87 |
+
---
|
88 |
+
### **📌 Response Guidelines**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
+
#### **1️⃣ Precision & Clarity**
|
91 |
+
- Format responses in **Markdown** for enhanced readability.
|
92 |
+
- Match the **response language** to the query's "Language" field.
|
93 |
+
- Ensure responses are **concise yet comprehensive**, avoiding excessive elaboration.
|
94 |
+
|
95 |
+
#### **2️⃣ Citing Sources Transparently**
|
96 |
+
- Use **numerical citations** ([1], [2], etc.) to indicate the source document of the information.
|
97 |
+
- Citations must be **placed immediately after the relevant statement**.
|
98 |
+
- Ensure citations map correctly to the order of documents in the provided context.
|
99 |
+
|
100 |
+
#### **3️⃣ Formatting for Readability**
|
101 |
+
- Use **bold text**, *italic text*, bullet points, and headings for emphasis.
|
102 |
+
- Organize responses into **logical sections** to improve structure.
|
103 |
+
- Provide **tables or bullet points** where appropriate for numerical/statistical data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
+
#### **4️⃣ Strictly Adhere to Context**
|
106 |
+
- Use **only** information from the provided context.
|
107 |
+
- **Do not** include external knowledge or speculate on missing details.
|
108 |
|
109 |
+
#### **5️⃣ Handling Missing or Insufficient Context**
|
110 |
+
- If the context does **not contain** a clear answer, respond with:
|
111 |
+
🛑 *"The provided context does not contain relevant information to answer your question."*
|
112 |
+
- If general knowledge is allowed, provide a well-informed but **non-speculative** response.
|
113 |
|
114 |
+
#### **6️⃣ Avoiding AI Hallucinations**
|
115 |
+
- **Do not fabricate data, statistics, or references**.
|
116 |
+
- **Do not assume missing details**—state explicitly if something is unclear.
|
117 |
|
118 |
+
#### **7️⃣ Self-Identification When Asked**
|
119 |
+
- If requested, clearly state:
|
120 |
+
*"I am an AI assistant developed by lawa.ai, designed to provide accurate responses based on provided context."*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
---
|
123 |
+
### **📌 Strict Rules for Response Generation**
|
124 |
+
✅ **Never mention the word "context" in responses.**
|
125 |
+
✅ **Use only the relevant content from the provided context.**
|
126 |
+
✅ **If no relevant information exists, say so explicitly.**
|
127 |
|
128 |
+
---
|
129 |
+
### **📌 Input Format Example**
|
130 |
+
**User Query:**
|
131 |
+
*"What are the latest updates on the scholarship policies at MBZUAI?"*
|
132 |
+
**Language:** *English*
|
133 |
+
**Context:**
|
134 |
+
```text
|
135 |
+
<provided context>
|
136 |
+
```
|
137 |
|
138 |
+
---
|
139 |
+
### **📌 Expected Output Format**
|
140 |
+
```markdown
|
141 |
+
### **Latest Updates on MBZUAI Scholarship Policies**
|
142 |
+
MBZUAI recently updated its scholarship policies to include the following:
|
143 |
+
|
144 |
+
1. **Scholarship Coverage:** Full tuition fees, accommodation, and a monthly stipend. [1]
|
145 |
+
2. **Eligibility Criteria:** Applicants must maintain a GPA of 3.5 or higher. [2]
|
146 |
+
|
147 |
+
For further details, please refer to the official documents. If you have more specific questions, feel free to ask!
|
148 |
+
```
|
149 |
+
|
150 |
+
---
|
151 |
+
### **📌 Example Question & Response**
|
152 |
+
#### **User Query:**
|
153 |
+
*"I overstayed my tourist visa in the UAE. What penalties or fines will I face, and how can I resolve this legally?"*
|
154 |
+
#### **Provided Context:**
|
155 |
+
```text
|
156 |
+
<related regulations on visa overstay penalties>
|
157 |
+
```
|
158 |
+
#### **Generated Response:**
|
159 |
+
```markdown
|
160 |
+
### **UAE Tourist Visa Overstay Penalties**
|
161 |
+
Overstaying a UAE tourist visa incurs specific penalties and requires prompt action to avoid legal issues.
|
162 |
+
|
163 |
+
#### **Fines & Fees**
|
164 |
+
- **Daily Fine:** AED 50 per day beyond the visa expiry. [1]
|
165 |
+
- **Exit Fee:** Additional AED 200 upon departure. [2]
|
166 |
+
|
167 |
+
#### **Steps to Resolve the Issue**
|
168 |
+
1. **Calculate Total Fines:** Multiply overstayed days by AED 50 and add any exit fees.
|
169 |
+
2. **Visit an Immigration Office:** Report to the General Directorate of Residency and Foreigners Affairs (GDRFA) or an Amer service center in Dubai.
|
170 |
+
3. **Pay the Fines:** Payments can be made at immigration offices, airports, land borders, or seaports upon departure. [3]
|
171 |
+
4. **Apply for a Visa Extension:** If you wish to stay longer, request a visa extension or status change before expiry. [4]
|
172 |
+
|
173 |
+
#### **Additional Considerations**
|
174 |
+
- **Grace Period:** Some visas offer a grace period before fines apply. [5]
|
175 |
+
- **Legal Assistance:** If needed, consult immigration experts for further guidance.
|
176 |
+
|
177 |
+
Acting promptly helps minimize fines and maintain a clean immigration record in the UAE.
|
178 |
+
```
|
179 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
+
# ------------------------------------------------------------------------------
|
182 |
+
# Pydantic models for request/response validation
|
183 |
+
# ------------------------------------------------------------------------------
|
184 |
+
class ChatRequest(BaseModel):
    # Payload that the /chat websocket handler builds from the first JSON
    # message it receives (ChatRequest(**data)).
    question: str = Field(..., max_length=1024)  # required; at most 1024 characters
    language: str  # inserted into the prompt's LANGUAGE field by format_query
    previous_chats: List[dict]  # earlier messages; appended as-is after the system prompt
|
188 |
|
189 |
+
class CitationSource(BaseModel):
    # One cited source: a URL plus its display number (kept as a string).
    # process_citations builds plain dicts with these same two keys.
    # NOTE(review): this model is not referenced elsewhere in the visible code - confirm it is still needed.
    url: str
    cite_num: str
|
192 |
|
193 |
+
# ------------------------------------------------------------------------------
|
194 |
+
# Initialize Pinecone retriever with retries
|
195 |
+
# ------------------------------------------------------------------------------
|
196 |
+
MAX_RETRIES = 3
|
197 |
|
198 |
+
def initialize_pinecone(index_name: str = "saudi-arabia-moj",
                        bm25_path: str = "./saudi-arabia-moj.json"):
    """Build the hybrid (dense + BM25) Pinecone retriever, retrying on failure.

    Up to ``MAX_RETRIES`` attempts are made, sleeping ``2 ** attempt`` seconds
    between failed attempts.

    Args:
        index_name: Name of the Pinecone index to open.
        bm25_path: Path of the saved BM25 encoder JSON file.

    Returns:
        A configured ``PineconeHybridSearchRetriever``.

    Raises:
        RuntimeError: If ``MAX_RETRIES`` is not positive, so no attempt ran.
        Exception: Whatever the last failed attempt raised.
    """
    for attempt in range(MAX_RETRIES):
        try:
            index = pc.Index(index_name)
            bm25 = BM25Encoder().load(bm25_path)
            return PineconeHybridSearchRetriever(
                embeddings=embed_model,
                sparse_encoder=bm25,
                index=index,
                top_k=40,  # Hardcoded as required
                alpha=0.6,  # Hardcoded as required
            )
        except Exception as e:
            logger.warning("Pinecone initialization attempt %d failed: %s", attempt + 1, e)
            if attempt == MAX_RETRIES - 1:
                raise
            time.sleep(2 ** attempt)

    # Only reachable when the loop body never ran; fail loudly rather than
    # handing callers a None retriever.
    raise RuntimeError("MAX_RETRIES must be a positive integer")
|
215 |
|
216 |
+
retriever = initialize_pinecone()
|
217 |
+
|
218 |
+
# ------------------------------------------------------------------------------
|
219 |
+
# Utility function to send messages safely over the websocket
|
220 |
+
# ------------------------------------------------------------------------------
|
221 |
+
async def safe_send(websocket: WebSocket, message: dict):
    """Send ``message`` as JSON over ``websocket``, logging any failure.

    The failure is always re-raised after logging, so callers still see it.

    Raises:
        WebSocketDisconnect: Logged at info level, then re-raised.
        Exception: Any other send failure, logged at error level, then re-raised.
    """
    try:
        await websocket.send_json(message)
    except Exception as send_error:
        # A disconnect is an expected event; everything else is a real error.
        if isinstance(send_error, WebSocketDisconnect):
            logger.info("Client disconnected during send")
        else:
            logger.error(f"Error sending message: {send_error}")
        raise
|
230 |
+
|
231 |
+
# ------------------------------------------------------------------------------
|
232 |
+
# Helper functions for document processing and query formatting
|
233 |
+
# ------------------------------------------------------------------------------
|
234 |
+
def rerank_docs(query: str, docs: List[dict], pc_client: "Pinecone",
                model: str = "cohere-rerank-3.5", top_n: int = 20) -> List[dict]:
    """Rerank ``docs`` against ``query`` with Pinecone's hosted reranker.

    Args:
        query: The user question to rank against.
        docs: Candidate documents; each dict is ranked on its ``"chunk"`` field.
        pc_client: Client exposing ``inference.rerank``.
        model: Reranking model name.
        top_n: Maximum number of documents to request back.

    Returns:
        Dicts with ``page_source``, ``chunk`` and ``summary`` keys, in the order
        returned by the reranker. An empty ``docs`` yields ``[]`` without
        calling the service.

    Raises:
        Exception: Any error from the rerank call or from reading its result,
            logged and re-raised.
    """
    if not docs:
        # Nothing to rank; skip the remote call entirely.
        return []
    try:
        result = pc_client.inference.rerank(
            model=model,
            query=query,
            documents=docs,
            rank_fields=["chunk"],
            top_n=top_n,
            return_documents=True,
        )
        return [
            {
                "page_source": ele.document.page_source,
                "chunk": ele.document.chunk,
                "summary": ele.document.summary,
            }
            for ele in result.data
        ]
    except Exception as e:
        logger.error("Error in rerank_docs: %s", e)
        raise
|
253 |
+
|
254 |
+
def format_docs(docs: List[dict]) -> str:
    """Render documents as the numbered context block given to the model.

    Args:
        docs: Dicts that each provide ``"page_source"`` and ``"chunk"``.

    Returns:
        One section per document, numbered from 1 in input order and each
        preceded by a 150-character ``=`` rule. Returns ``""`` for no documents.

    Raises:
        KeyError: If a document lacks ``"page_source"`` or ``"chunk"``.
    """
    separator = "=" * 150
    # Build all sections in one join instead of repeated string concatenation.
    return "".join(
        f"\n{separator}\n"
        f"**DOCUMENT:** {position}\n"
        f"**SOURCE:** {doc['page_source']}\n\n"
        f"**CONTENT:** {doc['chunk']}\n\n"
        for position, doc in enumerate(docs, start=1)
    )
|
264 |
+
|
265 |
+
def format_query(query: str, language: str, docs: List[dict]) -> str:
    """Compose the user message: the query, its language, and the formatted context.

    The context part is produced by ``format_docs(docs)``.
    """
    return f"**USER QUERY:** {query}\n**LANGUAGE:** {language}\n**CONTEXT:**\n{format_docs(docs)}"
|
268 |
+
|
269 |
+
def validate_citation_numbers(citation_numbers: List[int], max_docs: int) -> List[int]:
    """Keep only the citation numbers that fall within ``1..max_docs``.

    Order and duplicates of the input are preserved.
    """
    in_range = []
    for candidate in citation_numbers:
        if candidate < 1 or candidate > max_docs:
            continue
        in_range.append(candidate)
    return in_range
|
271 |
+
|
272 |
+
def process_citations(complete_answer: str, ranked_docs: List[dict]) -> Tuple[str, List[dict]]:
    """Renumber the ``[n]`` citations in an answer and link them to source URLs.

    Each ``[n]`` refers to ``ranked_docs[n - 1]``. Valid citations are
    renumbered consecutively from 1 in order of first appearance, and
    documents sharing the same ``"page_source"`` share one number. Markers
    whose number is outside ``1..len(ranked_docs)`` are left as plain ``[n]``.

    Args:
        complete_answer: Model output containing ``[n]`` markers.
        ranked_docs: Documents in the order they were shown to the model;
            each must provide ``"page_source"``.

    Returns:
        ``(updated_answer, citations)``. In ``updated_answer`` each resolvable
        marker becomes ``[k](url)``. ``citations`` is a list of
        ``{"url": ..., "cite_num": "k"}`` dicts in ascending ``k`` order.
    """
    citation_pattern = r'\[(\d+)\]'

    # Distinct cited numbers in order of first appearance (dicts keep insertion order).
    citation_numbers = list(dict.fromkeys(
        int(num_str) for num_str in re.findall(citation_pattern, complete_answer)
    ))
    # Validation guarantees 1 <= num <= len(ranked_docs), so the indexing
    # below cannot raise IndexError.
    valid_citations = validate_citation_numbers(citation_numbers, len(ranked_docs))

    seen_urls = {}       # url -> new citation number
    citation_map = {}    # original citation number -> new citation number
    url_by_number = {}   # new citation number -> url, for O(1) lookup on replace
    citations = []
    for num in valid_citations:
        url = ranked_docs[num - 1]["page_source"]
        new_num = seen_urls.get(url)
        if new_num is None:
            new_num = len(seen_urls) + 1
            seen_urls[url] = new_num
            url_by_number[new_num] = url
            citations.append({"url": url, "cite_num": str(new_num)})
        citation_map[num] = new_num

    logger.debug("Citation numbers extracted: %s", citation_numbers)
    logger.debug("Seen URLs mapping: %s", seen_urls)

    def replace_citation(match):
        original = int(match.group(1))
        new_num = citation_map.get(original, original)
        url = url_by_number.get(new_num, "")
        # An unresolvable marker, or one whose source URL is empty, stays unlinked.
        return f"[{new_num}]({url})" if url else f"[{new_num}]"

    updated_answer = re.sub(citation_pattern, replace_citation, complete_answer)
    # Numbers were assigned in increasing order, so the list is already sorted.
    return updated_answer, citations
|
314 |
+
|
315 |
+
# ------------------------------------------------------------------------------
|
316 |
+
# WebSocket endpoint for chat functionality with improved error handling
|
317 |
+
# ------------------------------------------------------------------------------
|
318 |
+
@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one retrieval-augmented chat turn over a websocket.

    Steps: receive and validate a ``ChatRequest``, retrieve documents, rerank
    them (falling back to retrieval order on failure), stream the model's
    answer with citation markers removed, then send the full answer with
    linked citations and the list of sources.

    Every failure is reported to the client as a
    ``{"response": ..., "sources": []}`` message when the socket is still
    open. A client disconnect is logged and ends the handler quietly.
    """
    await websocket.accept()
    try:
        # Receive and validate the request
        try:
            data = await asyncio.wait_for(websocket.receive_json(), timeout=30)
            chat_request = ChatRequest(**data)
        except WebSocketDisconnect:
            # Nobody is left to notify; let the outer handler log it.
            raise
        except ValidationError as e:
            logger.error("Validation error: %s", e)
            await safe_send(websocket, {"response": "Something went wrong with your request!", "sources": []})
            return
        except Exception as e:
            logger.error("Error receiving data: %s", e)
            await safe_send(websocket, {"response": "Something went wrong with your request!", "sources": []})
            return

        question = chat_request.question
        language = chat_request.language

        # Retrieve documents using the retriever
        try:
            retrieved_docs = await asyncio.to_thread(retriever.invoke, question)
        except Exception as e:
            logger.error("Document retrieval error: %s", e)
            await safe_send(websocket, {"response": "Document retrieval failed", "sources": []})
            return

        docs = [{
            "summary": ele.metadata.get("summary", ""),
            "chunk": ele.page_content,
            "page_source": ele.metadata.get("source", "")
        } for ele in retrieved_docs]

        if not docs:
            await safe_send(websocket, {"response": "Cannot provide answer to this question", "sources": []})
            return

        # Rerank the documents (fallback to original docs if reranking fails)
        try:
            ranked_docs = await asyncio.to_thread(rerank_docs, question, docs, pc)
        except Exception as e:
            logger.error("Reranking error: %s", e)
            ranked_docs = docs

        # Prepare the conversation messages
        # NOTE(review): previous_chats comes straight from the client and is
        # forwarded unfiltered, so a client can supply its own "system"
        # messages. Consider restricting roles to user/assistant.
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(chat_request.previous_chats)
        messages.append({"role": "user", "content": format_query(question, language, ranked_docs)})

        answer_parts = []
        # Raw text received but not yet forwarded. It only ever holds a
        # possibly unfinished citation marker such as "[" or "[12".
        pending = ""

        # Generate and stream the chat response
        try:
            completion = await openai_client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=0.2,
                max_completion_tokens=1024,
                stream=True
            )
            async for chunk in completion:
                if not chunk.choices:
                    # Some stream chunks carry no choices; skip them.
                    continue
                delta_content = chunk.choices[0].delta.content
                if not delta_content:
                    continue
                answer_parts.append(delta_content)
                pending += delta_content

                # A marker like "[12]" can be split across deltas. Hold back a
                # trailing "[digits" until the next delta shows whether it
                # closes, so partial markers never reach the client.
                partial = re.search(r'\[\d*\Z', pending)
                if partial:
                    ready, pending = pending[:partial.start()], pending[partial.start():]
                else:
                    ready, pending = pending, ""

                # Remove inline citation markers from the streamed text before sending
                cleaned_content = re.sub(r'\[\d+\]', '', ready)
                if cleaned_content:
                    await safe_send(websocket, {"response": cleaned_content})
            if pending:
                # The stream ended on an unclosed "[digits"; it is ordinary text.
                await safe_send(websocket, {"response": pending})
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.error("Streaming error: %s", e)
            await safe_send(websocket, {"response": "Response generation failed", "sources": []})
            return

        # Process and map citations in the final answer
        complete_answer, citations = process_citations("".join(answer_parts), ranked_docs)

        await safe_send(websocket, {
            "response": complete_answer,
            "sources": citations
        })

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error("Unexpected error: %s", e)
        try:
            await safe_send(websocket, {"response": "Something went wrong! Please try again.", "sources": []})
        except Exception:
            # The socket is probably already closed, and safe_send has logged
            # the failure. Do not let it escape the endpoint.
            logger.info("Could not deliver the error message to the client")
|
410 |
+
|
411 |
+
# ------------------------------------------------------------------------------
|
412 |
+
# Simple health check endpoint
|
413 |
+
# ------------------------------------------------------------------------------
|
414 |
+
@app.get("/")
async def root():
    # Health check: always reports that the service is up.
    status_payload = {"message": "working"}
    return JSONResponse(content=status_payload)
|