chenguittiMaroua committed on
Commit
94bcc5a
·
verified ·
1 Parent(s): dba5d7e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +85 -70
main.py CHANGED
@@ -127,30 +127,24 @@ def get_summarizer():
127
  # if qa_model is None:
128
  # qa_model= pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
129
  #return qa_model
130
- from transformers import RagTokenizer, RagTokenForGeneration, pipeline
131
 
132
- qa_model = None
133
- rag_model = None
134
 
135
  def get_qa_model():
136
  global qa_model
137
  if qa_model is None:
138
- qa_model = pipeline(
139
- "question-answering",
140
- model="deepset/roberta-base-squad2"
141
- )
 
 
 
 
 
142
  return qa_model
143
 
144
- def get_rag_model():
145
- global rag_model
146
- if rag_model is None:
147
- rag_model = pipeline(
148
- "text-generation",
149
- model="facebook/rag-token-nq",
150
- tokenizer="facebook/rag-token-nq"
151
- )
152
- return rag_model
153
-
154
 
155
 
156
 
@@ -161,7 +155,7 @@ def get_image_captioner():
161
  return image_captioner
162
 
163
  async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
164
- """Validate and process uploaded file with special handling for each type"""
165
  if not file.filename:
166
  raise HTTPException(400, "No filename provided")
167
 
@@ -173,7 +167,6 @@ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
173
  if len(content) > MAX_FILE_SIZE:
174
  raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
175
 
176
- # Special validation for PDFs
177
  if file_ext == "pdf":
178
  try:
179
  with fitz.open(stream=content, filetype="pdf") as doc:
@@ -186,14 +179,13 @@ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
186
  logger.error(f"PDF validation failed: {str(e)}")
187
  raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
188
 
189
- await file.seek(0) # Reset file pointer for processing
190
  return file_ext, content
191
 
192
  def extract_text(content: bytes, file_ext: str) -> str:
193
- """Extract text from various file formats with enhanced Excel support"""
194
  try:
195
  if file_ext == "txt":
196
- # Decode plain text (handle encoding issues)
197
  return content.decode("utf-8", errors="replace").strip()
198
 
199
  if file_ext == "docx":
@@ -201,7 +193,6 @@ def extract_text(content: bytes, file_ext: str) -> str:
201
  return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
202
 
203
  elif file_ext in {"xlsx", "xls"}:
204
- # Improved Excel handling with better NaN and date support
205
  df = pd.read_excel(
206
  io.BytesIO(content),
207
  sheet_name=None,
@@ -214,12 +205,9 @@ def extract_text(content: bytes, file_ext: str) -> str:
214
  all_text = []
215
  for sheet_name, sheet_data in df.items():
216
  sheet_text = []
217
- # Convert all data to string and handle special types
218
  for column in sheet_data.columns:
219
- # Handle datetime columns
220
  if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
221
  sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
222
- # Convert to string and clean
223
  col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
224
  sheet_text.extend([x for x in col_text if x.strip()])
225
 
@@ -241,14 +229,12 @@ def extract_text(content: bytes, file_ext: str) -> str:
241
  return "\n".join(page.get_text("text") for page in pdf)
242
 
243
  elif file_ext in {"jpg", "jpeg", "png"}:
244
- # First try OCR
245
  try:
246
  image = Image.open(io.BytesIO(content))
247
  text = pytesseract.image_to_string(image, config='--psm 6')
248
  if text.strip():
249
  return text
250
 
251
- # If OCR fails, try image captioning
252
  captioner = get_image_captioner()
253
  result = captioner(image)
254
  return result[0]['generated_text']
@@ -260,6 +246,19 @@ def extract_text(content: bytes, file_ext: str) -> str:
260
  logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
261
  raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}")
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  # Visualization Models
264
  class VisualizationRequest(BaseModel):
265
  chart_type: str
@@ -833,61 +832,77 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
833
  from typing import Optional
834
 
835
  @app.post("/qa")
836
- @limiter.limit("5/minute")
837
  async def question_answering(
838
  request: Request,
839
- file: Optional[UploadFile] = File(None), # Make file optional
840
  question: str = Form(...),
841
- language: str = Form("fr")
 
842
  ):
843
- # Validate question
844
- if not question.strip():
845
- raise HTTPException(400, "Question cannot be empty")
846
-
847
- # Check if the question is about the document
848
- is_doc_question = any(
849
- kw in question.lower()
850
- for kw in ["document", "file", "text", "this pdf", "this doc"]
851
- )
852
-
853
- # (A) If file is provided and question is about it → Document QA
854
- if file and is_doc_question:
855
- try:
856
- file_ext, content = await process_uploaded_file(file)
857
- text = extract_text(content, file_ext)
858
- text = re.sub(r'\s+', ' ', text).strip()[:5000]
859
 
860
- qa = get_qa_model()
861
- result = qa(question=question, context=text[:3000])
862
-
863
- return {
864
- "question": question,
865
- "answer": result["answer"],
866
- "confidence": result["score"],
867
- "source": "document",
868
- "language": language
869
- }
870
- except Exception as e:
871
- logger.error(f"Doc QA failed: {str(e)}")
872
- raise HTTPException(500, "Failed to analyze document")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873
 
874
- # (B) If no file or general question → Open-domain QA (RAG)
875
- else:
876
  try:
877
- rag = get_rag_model()
878
- answer = rag(question)[0]["generated_text"]
879
 
880
- return {
881
  "question": question,
882
- "answer": answer,
883
- "confidence": 0.8, # RAG doesn't return scores
884
  "source": "general knowledge",
885
  "language": language
886
- }
 
887
  except Exception as e:
888
- logger.error(f"RAG failed: {str(e)}")
889
- raise HTTPException(500, "Failed to fetch general answer")
890
 
 
 
 
 
 
 
 
 
 
891
 
892
 
893
 
 
127
  # if qa_model is None:
128
  # qa_model= pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
129
  #return qa_model
 
130
 
131
+
132
+
133
 
134
  def get_qa_model():
135
  global qa_model
136
  if qa_model is None:
137
+ try:
138
+ qa_model = pipeline(
139
+ "text2text-generation",
140
+ model="google/flan-t5-base",
141
+ device=0 if torch.cuda.is_available() else -1
142
+ )
143
+ except Exception as e:
144
+ logger.error(f"Failed to load QA model: {str(e)}")
145
+ raise HTTPException(500, "Failed to initialize QA system")
146
  return qa_model
147
 
 
 
 
 
 
 
 
 
 
 
148
 
149
 
150
 
 
155
  return image_captioner
156
 
157
  async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
158
+ """Your existing file processing function"""
159
  if not file.filename:
160
  raise HTTPException(400, "No filename provided")
161
 
 
167
  if len(content) > MAX_FILE_SIZE:
168
  raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
169
 
 
170
  if file_ext == "pdf":
171
  try:
172
  with fitz.open(stream=content, filetype="pdf") as doc:
 
179
  logger.error(f"PDF validation failed: {str(e)}")
180
  raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
181
 
182
+ await file.seek(0)
183
  return file_ext, content
184
 
185
  def extract_text(content: bytes, file_ext: str) -> str:
186
+ """Your existing text extraction function"""
187
  try:
188
  if file_ext == "txt":
 
189
  return content.decode("utf-8", errors="replace").strip()
190
 
191
  if file_ext == "docx":
 
193
  return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
194
 
195
  elif file_ext in {"xlsx", "xls"}:
 
196
  df = pd.read_excel(
197
  io.BytesIO(content),
198
  sheet_name=None,
 
205
  all_text = []
206
  for sheet_name, sheet_data in df.items():
207
  sheet_text = []
 
208
  for column in sheet_data.columns:
 
209
  if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
210
  sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
 
211
  col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
212
  sheet_text.extend([x for x in col_text if x.strip()])
213
 
 
229
  return "\n".join(page.get_text("text") for page in pdf)
230
 
231
  elif file_ext in {"jpg", "jpeg", "png"}:
 
232
  try:
233
  image = Image.open(io.BytesIO(content))
234
  text = pytesseract.image_to_string(image, config='--psm 6')
235
  if text.strip():
236
  return text
237
 
 
238
  captioner = get_image_captioner()
239
  result = captioner(image)
240
  return result[0]['generated_text']
 
246
  logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
247
  raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}")
248
 
249
+
250
+
251
+
252
+
253
+
254
+
255
+
256
+
257
+
258
+
259
+
260
+
261
+
262
  # Visualization Models
263
  class VisualizationRequest(BaseModel):
264
  chart_type: str
 
832
  from typing import Optional
833
 
834
  @app.post("/qa")
 
835
  async def question_answering(
836
  request: Request,
 
837
  question: str = Form(...),
838
+ file: Optional[UploadFile] = File(None),
839
+ language: str = Form("en")
840
  ):
841
+ """
842
+ Enhanced QA endpoint that:
843
+ - Processes uploaded files using your existing functions
844
+ - Answers questions using FLAN-T5
845
+ - Handles both document and general knowledge questions
846
+ """
847
+ try:
848
+ # Validate question
849
+ if not question.strip():
850
+ raise HTTPException(400, "Question cannot be empty")
 
 
 
 
 
 
851
 
852
+ qa_pipeline = get_qa_model()
853
+
854
+ # Case 1: Document QA (when file is provided)
855
+ if file:
856
+ try:
857
+ file_ext, content = await process_uploaded_file(file)
858
+ text = extract_text(content, file_ext)
859
+
860
+ # Clean and truncate text
861
+ clean_text = re.sub(r'\s+', ' ', text).strip()[:5000]
862
+
863
+ # Format for FLAN-T5 (combine question and context)
864
+ input_text = f"Answer this question based on the given context. Question: {question} Context: {clean_text}"
865
+ result = qa_pipeline(input_text, max_length=200)
866
+
867
+ return JSONResponse({
868
+ "question": question,
869
+ "answer": result[0]["generated_text"],
870
+ "source": "document",
871
+ "language": language,
872
+ "file_type": file_ext
873
+ })
874
+
875
+ except HTTPException:
876
+ raise
877
+ except Exception as e:
878
+ logger.error(f"Document QA failed: {str(e)}")
879
+ raise HTTPException(500, "Failed to analyze document")
880
 
881
+ # Case 2: General QA (no file provided)
 
882
  try:
883
+ input_text = f"Answer this question: {question}"
884
+ result = qa_pipeline(input_text, max_length=200)
885
 
886
+ return JSONResponse({
887
  "question": question,
888
+ "answer": result[0]["generated_text"],
 
889
  "source": "general knowledge",
890
  "language": language
891
+ })
892
+
893
  except Exception as e:
894
+ logger.error(f"General QA failed: {str(e)}")
895
+ raise HTTPException(500, "Failed to generate answer")
896
 
897
+ except HTTPException:
898
+ raise
899
+ except Exception as e:
900
+ logger.critical(f"Unexpected error: {str(e)}")
901
+ raise HTTPException(500, "Internal server error")
902
+
903
+
904
+
905
+
906
 
907
 
908