Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -127,30 +127,24 @@ def get_summarizer():
|
|
127 |
# if qa_model is None:
|
128 |
# qa_model= pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
|
129 |
#return qa_model
|
130 |
-
from transformers import RagTokenizer, RagTokenForGeneration, pipeline
|
131 |
|
132 |
-
|
133 |
-
|
134 |
|
135 |
def get_qa_model():
|
136 |
global qa_model
|
137 |
if qa_model is None:
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
142 |
return qa_model
|
143 |
|
144 |
-
def get_rag_model():
|
145 |
-
global rag_model
|
146 |
-
if rag_model is None:
|
147 |
-
rag_model = pipeline(
|
148 |
-
"text-generation",
|
149 |
-
model="facebook/rag-token-nq",
|
150 |
-
tokenizer="facebook/rag-token-nq"
|
151 |
-
)
|
152 |
-
return rag_model
|
153 |
-
|
154 |
|
155 |
|
156 |
|
@@ -161,7 +155,7 @@ def get_image_captioner():
|
|
161 |
return image_captioner
|
162 |
|
163 |
async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
|
164 |
-
"""
|
165 |
if not file.filename:
|
166 |
raise HTTPException(400, "No filename provided")
|
167 |
|
@@ -173,7 +167,6 @@ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
|
|
173 |
if len(content) > MAX_FILE_SIZE:
|
174 |
raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
|
175 |
|
176 |
-
# Special validation for PDFs
|
177 |
if file_ext == "pdf":
|
178 |
try:
|
179 |
with fitz.open(stream=content, filetype="pdf") as doc:
|
@@ -186,14 +179,13 @@ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
|
|
186 |
logger.error(f"PDF validation failed: {str(e)}")
|
187 |
raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
|
188 |
|
189 |
-
await file.seek(0)
|
190 |
return file_ext, content
|
191 |
|
192 |
def extract_text(content: bytes, file_ext: str) -> str:
|
193 |
-
"""
|
194 |
try:
|
195 |
if file_ext == "txt":
|
196 |
-
# Decode plain text (handle encoding issues)
|
197 |
return content.decode("utf-8", errors="replace").strip()
|
198 |
|
199 |
if file_ext == "docx":
|
@@ -201,7 +193,6 @@ def extract_text(content: bytes, file_ext: str) -> str:
|
|
201 |
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
|
202 |
|
203 |
elif file_ext in {"xlsx", "xls"}:
|
204 |
-
# Improved Excel handling with better NaN and date support
|
205 |
df = pd.read_excel(
|
206 |
io.BytesIO(content),
|
207 |
sheet_name=None,
|
@@ -214,12 +205,9 @@ def extract_text(content: bytes, file_ext: str) -> str:
|
|
214 |
all_text = []
|
215 |
for sheet_name, sheet_data in df.items():
|
216 |
sheet_text = []
|
217 |
-
# Convert all data to string and handle special types
|
218 |
for column in sheet_data.columns:
|
219 |
-
# Handle datetime columns
|
220 |
if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
|
221 |
sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
|
222 |
-
# Convert to string and clean
|
223 |
col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
|
224 |
sheet_text.extend([x for x in col_text if x.strip()])
|
225 |
|
@@ -241,14 +229,12 @@ def extract_text(content: bytes, file_ext: str) -> str:
|
|
241 |
return "\n".join(page.get_text("text") for page in pdf)
|
242 |
|
243 |
elif file_ext in {"jpg", "jpeg", "png"}:
|
244 |
-
# First try OCR
|
245 |
try:
|
246 |
image = Image.open(io.BytesIO(content))
|
247 |
text = pytesseract.image_to_string(image, config='--psm 6')
|
248 |
if text.strip():
|
249 |
return text
|
250 |
|
251 |
-
# If OCR fails, try image captioning
|
252 |
captioner = get_image_captioner()
|
253 |
result = captioner(image)
|
254 |
return result[0]['generated_text']
|
@@ -260,6 +246,19 @@ def extract_text(content: bytes, file_ext: str) -> str:
|
|
260 |
logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
|
261 |
raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}")
|
262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
# Visualization Models
|
264 |
class VisualizationRequest(BaseModel):
|
265 |
chart_type: str
|
@@ -833,61 +832,77 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
|
|
833 |
from typing import Optional
|
834 |
|
835 |
@app.post("/qa")
|
836 |
-
@limiter.limit("5/minute")
|
837 |
async def question_answering(
|
838 |
request: Request,
|
839 |
-
file: Optional[UploadFile] = File(None), # Make file optional
|
840 |
question: str = Form(...),
|
841 |
-
|
|
|
842 |
):
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
# (A) If file is provided and question is about it → Document QA
|
854 |
-
if file and is_doc_question:
|
855 |
-
try:
|
856 |
-
file_ext, content = await process_uploaded_file(file)
|
857 |
-
text = extract_text(content, file_ext)
|
858 |
-
text = re.sub(r'\s+', ' ', text).strip()[:5000]
|
859 |
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
|
865 |
-
|
866 |
-
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
872 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
873 |
|
874 |
-
|
875 |
-
else:
|
876 |
try:
|
877 |
-
|
878 |
-
|
879 |
|
880 |
-
return {
|
881 |
"question": question,
|
882 |
-
"answer":
|
883 |
-
"confidence": 0.8, # RAG doesn't return scores
|
884 |
"source": "general knowledge",
|
885 |
"language": language
|
886 |
-
}
|
|
|
887 |
except Exception as e:
|
888 |
-
logger.error(f"
|
889 |
-
raise HTTPException(500, "Failed to
|
890 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
891 |
|
892 |
|
893 |
|
|
|
127 |
# if qa_model is None:
|
128 |
# qa_model= pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
|
129 |
#return qa_model
|
|
|
130 |
|
131 |
+
|
132 |
+
|
133 |
|
134 |
def get_qa_model():
|
135 |
global qa_model
|
136 |
if qa_model is None:
|
137 |
+
try:
|
138 |
+
qa_model = pipeline(
|
139 |
+
"text2text-generation",
|
140 |
+
model="google/flan-t5-base",
|
141 |
+
device=0 if torch.cuda.is_available() else -1
|
142 |
+
)
|
143 |
+
except Exception as e:
|
144 |
+
logger.error(f"Failed to load QA model: {str(e)}")
|
145 |
+
raise HTTPException(500, "Failed to initialize QA system")
|
146 |
return qa_model
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
|
150 |
|
|
|
155 |
return image_captioner
|
156 |
|
157 |
async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
|
158 |
+
"""Your existing file processing function"""
|
159 |
if not file.filename:
|
160 |
raise HTTPException(400, "No filename provided")
|
161 |
|
|
|
167 |
if len(content) > MAX_FILE_SIZE:
|
168 |
raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
|
169 |
|
|
|
170 |
if file_ext == "pdf":
|
171 |
try:
|
172 |
with fitz.open(stream=content, filetype="pdf") as doc:
|
|
|
179 |
logger.error(f"PDF validation failed: {str(e)}")
|
180 |
raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
|
181 |
|
182 |
+
await file.seek(0)
|
183 |
return file_ext, content
|
184 |
|
185 |
def extract_text(content: bytes, file_ext: str) -> str:
|
186 |
+
"""Your existing text extraction function"""
|
187 |
try:
|
188 |
if file_ext == "txt":
|
|
|
189 |
return content.decode("utf-8", errors="replace").strip()
|
190 |
|
191 |
if file_ext == "docx":
|
|
|
193 |
return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
|
194 |
|
195 |
elif file_ext in {"xlsx", "xls"}:
|
|
|
196 |
df = pd.read_excel(
|
197 |
io.BytesIO(content),
|
198 |
sheet_name=None,
|
|
|
205 |
all_text = []
|
206 |
for sheet_name, sheet_data in df.items():
|
207 |
sheet_text = []
|
|
|
208 |
for column in sheet_data.columns:
|
|
|
209 |
if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
|
210 |
sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
211 |
col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
|
212 |
sheet_text.extend([x for x in col_text if x.strip()])
|
213 |
|
|
|
229 |
return "\n".join(page.get_text("text") for page in pdf)
|
230 |
|
231 |
elif file_ext in {"jpg", "jpeg", "png"}:
|
|
|
232 |
try:
|
233 |
image = Image.open(io.BytesIO(content))
|
234 |
text = pytesseract.image_to_string(image, config='--psm 6')
|
235 |
if text.strip():
|
236 |
return text
|
237 |
|
|
|
238 |
captioner = get_image_captioner()
|
239 |
result = captioner(image)
|
240 |
return result[0]['generated_text']
|
|
|
246 |
logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
|
247 |
raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}")
|
248 |
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
|
253 |
+
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
|
262 |
# Visualization Models
|
263 |
class VisualizationRequest(BaseModel):
|
264 |
chart_type: str
|
|
|
832 |
from typing import Optional
|
833 |
|
834 |
@app.post("/qa")
|
|
|
835 |
async def question_answering(
|
836 |
request: Request,
|
|
|
837 |
question: str = Form(...),
|
838 |
+
file: Optional[UploadFile] = File(None),
|
839 |
+
language: str = Form("en")
|
840 |
):
|
841 |
+
"""
|
842 |
+
Enhanced QA endpoint that:
|
843 |
+
- Processes uploaded files using your existing functions
|
844 |
+
- Answers questions using FLAN-T5
|
845 |
+
- Handles both document and general knowledge questions
|
846 |
+
"""
|
847 |
+
try:
|
848 |
+
# Validate question
|
849 |
+
if not question.strip():
|
850 |
+
raise HTTPException(400, "Question cannot be empty")
|
|
|
|
|
|
|
|
|
|
|
|
|
851 |
|
852 |
+
qa_pipeline = get_qa_model()
|
853 |
+
|
854 |
+
# Case 1: Document QA (when file is provided)
|
855 |
+
if file:
|
856 |
+
try:
|
857 |
+
file_ext, content = await process_uploaded_file(file)
|
858 |
+
text = extract_text(content, file_ext)
|
859 |
+
|
860 |
+
# Clean and truncate text
|
861 |
+
clean_text = re.sub(r'\s+', ' ', text).strip()[:5000]
|
862 |
+
|
863 |
+
# Format for FLAN-T5 (combine question and context)
|
864 |
+
input_text = f"Answer this question based on the given context. Question: {question} Context: {clean_text}"
|
865 |
+
result = qa_pipeline(input_text, max_length=200)
|
866 |
+
|
867 |
+
return JSONResponse({
|
868 |
+
"question": question,
|
869 |
+
"answer": result[0]["generated_text"],
|
870 |
+
"source": "document",
|
871 |
+
"language": language,
|
872 |
+
"file_type": file_ext
|
873 |
+
})
|
874 |
+
|
875 |
+
except HTTPException:
|
876 |
+
raise
|
877 |
+
except Exception as e:
|
878 |
+
logger.error(f"Document QA failed: {str(e)}")
|
879 |
+
raise HTTPException(500, "Failed to analyze document")
|
880 |
|
881 |
+
# Case 2: General QA (no file provided)
|
|
|
882 |
try:
|
883 |
+
input_text = f"Answer this question: {question}"
|
884 |
+
result = qa_pipeline(input_text, max_length=200)
|
885 |
|
886 |
+
return JSONResponse({
|
887 |
"question": question,
|
888 |
+
"answer": result[0]["generated_text"],
|
|
|
889 |
"source": "general knowledge",
|
890 |
"language": language
|
891 |
+
})
|
892 |
+
|
893 |
except Exception as e:
|
894 |
+
logger.error(f"General QA failed: {str(e)}")
|
895 |
+
raise HTTPException(500, "Failed to generate answer")
|
896 |
|
897 |
+
except HTTPException:
|
898 |
+
raise
|
899 |
+
except Exception as e:
|
900 |
+
logger.critical(f"Unexpected error: {str(e)}")
|
901 |
+
raise HTTPException(500, "Internal server error")
|
902 |
+
|
903 |
+
|
904 |
+
|
905 |
+
|
906 |
|
907 |
|
908 |
|