chenguittiMaroua committed (verified)
Commit c6e8137 · 1 Parent(s): 56f1984

Update main.py

Files changed (1)
  1. main.py +155 -143
main.py CHANGED
@@ -1,10 +1,11 @@
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse
  from transformers import pipeline
  import io
  import fitz  # PyMuPDF
- from PIL import Image, UnidentifiedImageError
  import pandas as pd
  import uvicorn
  from docx import Document
@@ -12,8 +13,13 @@ from pptx import Presentation
  import pytesseract
  import logging
  import re
- from typing import Tuple
- import traceback

  # Configure logging
  logging.basicConfig(level=logging.INFO)
@@ -21,6 +27,10 @@ logger = logging.getLogger(__name__)

  app = FastAPI()

  # CORS Configuration
  app.add_middleware(
      CORSMiddleware,
@@ -32,176 +42,172 @@ app.add_middleware(
  # Constants
  MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
  SUPPORTED_FILE_TYPES = {
-     "docx": "Word Document",
-     "xlsx": "Excel Spreadsheet",
-     "pptx": "PowerPoint",
-     "pdf": "PDF",
-     "jpg": "JPEG Image",
-     "jpeg": "JPEG Image",
-     "png": "PNG Image"
  }

- # Initialize models at startup
- try:
-     logger.info("Loading ML models...")
-     summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-     qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
-     image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
-     logger.info("Models loaded successfully")
- except Exception as e:
-     logger.error(f"Failed to load models: {str(e)}")
-     raise RuntimeError("Model initialization failed")
-
- async def validate_file(file: UploadFile) -> Tuple[str, bytes]:
-     """Validate file type and size"""
      if not file.filename:
          raise HTTPException(400, "No filename provided")

      file_ext = file.filename.split('.')[-1].lower()
      if file_ext not in SUPPORTED_FILE_TYPES:
-         raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES.values())}")

      content = await file.read()
      if len(content) > MAX_FILE_SIZE:
          raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

-     await file.seek(0)
      return file_ext, content

- def extract_text_from_pdf(content: bytes) -> str:
-     """Extract text from PDF with error handling"""
-     try:
-         with fitz.open(stream=content, filetype="pdf") as doc:
-             if doc.is_encrypted:
-                 if not doc.authenticate(""):  # Try empty password
-                     raise ValueError("Encrypted PDF - cannot extract text")
-             return "\n".join(page.get_text("text") for page in doc)
-     except Exception as e:
-         logger.error(f"PDF extraction failed: {str(e)}")
-         raise ValueError(f"Failed to process PDF: {str(e)}")
-
- def extract_text_from_docx(content: bytes) -> str:
-     """Extract text from Word document"""
-     try:
-         doc = Document(io.BytesIO(content))
-         return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
-     except Exception as e:
-         logger.error(f"DOCX extraction failed: {str(e)}")
-         raise ValueError("Failed to process Word document")
-
- def extract_text_from_excel(content: bytes) -> str:
-     """Extract text from Excel (first sheet only)"""
-     try:
-         df = pd.read_excel(io.BytesIO(content), sheet_name=0)
-         return "\n".join(df.iloc[:, 0].dropna().astype(str).tolist())
-     except Exception as e:
-         logger.error(f"Excel extraction failed: {str(e)}")
-         raise ValueError("Failed to process Excel file")
-
- def extract_text_from_pptx(content: bytes) -> str:
-     """Extract text from PowerPoint"""
      try:
-         ppt = Presentation(io.BytesIO(content))
-         return "\n".join(shape.text for slide in ppt.slides
-                          for shape in slide.shapes if hasattr(shape, "text"))
-     except Exception as e:
-         logger.error(f"PPTX extraction failed: {str(e)}")
-         raise ValueError("Failed to process PowerPoint file")
-
- def extract_text_from_image(content: bytes) -> str:
-     """Extract text from image using OCR or captioning"""
-     try:
-         image = Image.open(io.BytesIO(content))

-         # First try OCR
-         try:
-             text = pytesseract.image_to_string(image, timeout=10)  # 10 second timeout
-             if text.strip():
-                 return text
-         except Exception as ocr_error:
-             logger.warning(f"OCR failed: {str(ocr_error)}")

-         # Fallback to image captioning
-         try:
-             caption = image_captioner(image)[0]['generated_text']
-             return f"Image description: {caption}"
-         except Exception as caption_error:
-             logger.error(f"Image captioning failed: {str(caption_error)}")
-             raise ValueError("Could not process image")
-
-     except UnidentifiedImageError:
-         raise ValueError("Invalid image file")
      except Exception as e:
-         logger.error(f"Image processing failed: {str(e)}")
-         raise ValueError("Failed to process image")
-
- EXTRACTION_FUNCTIONS = {
-     "pdf": extract_text_from_pdf,
-     "docx": extract_text_from_docx,
-     "xlsx": extract_text_from_excel,
-     "pptx": extract_text_from_pptx,
-     "jpg": extract_text_from_image,
-     "jpeg": extract_text_from_image,
-     "png": extract_text_from_image
- }

  @app.post("/summarize")
- async def summarize_document(file: UploadFile = File(...)):
      try:
-         file_ext, content = await validate_file(file)
-
-         # Get the appropriate extraction function
-         extractor = EXTRACTION_FUNCTIONS.get(file_ext)
-         if not extractor:
-             raise HTTPException(400, "Unsupported file type")

-         # Extract text
-         text = extractor(content)
          if not text.strip():
              raise HTTPException(400, "No extractable text found")

-         # Clean and summarize
-         clean_text = re.sub(r'\s+', ' ', text).strip()[:3000]  # Limit to 3000 chars
-         summary = summarizer(clean_text, max_length=150, min_length=30, do_sample=False)[0]["summary_text"]

-         return {"summary": summary}

-     except HTTPException as he:
-         raise he
-     except ValueError as ve:
-         logger.error(f"Processing error: {str(ve)}")
-         raise HTTPException(422, detail=str(ve))
      except Exception as e:
-         logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
-         raise HTTPException(500, detail=f"Document processing failed: {str(e)}")

  @app.post("/qa")
  async def question_answering(
      file: UploadFile = File(...),
      question: str = Form(...),
      language: str = Form("fr")
  ):
      try:
-         file_ext, content = await validate_file(file)
-
-         # Get the appropriate extraction function
-         extractor = EXTRACTION_FUNCTIONS.get(file_ext)
-         if not extractor:
-             raise HTTPException(400, "Unsupported file type")

-         # Extract text
-         text = extractor(content)
          if not text.strip():
              raise HTTPException(400, "No extractable text found")
-
-         # Clean text
-         clean_text = re.sub(r'\s+', ' ', text).strip()[:3000]  # Limit to 3000 chars
-
-         # Check for theme questions
          theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
          if any(kw in question.lower() for kw in theme_keywords):
              try:
-                 theme = summarizer(clean_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
                  return {
                      "question": question,
                      "answer": f"Le document traite principalement de : {theme}",
@@ -209,7 +215,7 @@ async def question_answering(
                      "language": language
                  }
              except Exception:
-                 theme = clean_text[:200] + ("..." if len(clean_text) > 200 else "")
                  return {
                      "question": question,
                      "answer": f"D'après le document : {theme}",
@@ -217,24 +223,30 @@ async def question_answering(
                      "language": language,
                      "warning": "theme_summary_fallback"
                  }
-
          # Standard QA
-         result = qa_model(question=question, context=clean_text)
          return {
              "question": question,
              "answer": result["answer"],
              "confidence": result["score"],
              "language": language
          }
-
-     except HTTPException as he:
-         raise he
-     except ValueError as ve:
-         logger.error(f"Processing error: {str(ve)}")
-         raise HTTPException(422, detail=str(ve))
      except Exception as e:
-         logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
-         raise HTTPException(500, detail=f"Question answering failed: {str(e)}")

  if __name__ == "__main__":
      uvicorn.run(app, host="0.0.0.0", port=7860)
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse
  from transformers import pipeline
+ from typing import Tuple
  import io
  import fitz  # PyMuPDF
+ from PIL import Image
  import pandas as pd
  import uvicorn
  from docx import Document

  import pytesseract
  import logging
  import re
+ from slowapi import Limiter
+ from slowapi.util import get_remote_address
+ from slowapi.errors import RateLimitExceeded
+ from slowapi.middleware import SlowAPIMiddleware
+
+ # Initialize rate limiter
+ limiter = Limiter(key_func=get_remote_address)

  # Configure logging
  logging.basicConfig(level=logging.INFO)

  app = FastAPI()

+ # Apply rate limiting middleware
+ app.state.limiter = limiter
+ app.add_middleware(SlowAPIMiddleware)
+
  # CORS Configuration
  app.add_middleware(
      CORSMiddleware,

  # Constants
  MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
  SUPPORTED_FILE_TYPES = {
+     "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
  }

+ # Model caching
+ summarizer = None
+ qa_model = None
+ image_captioner = None
+
+ def get_summarizer():
+     global summarizer
+     if summarizer is None:
+         summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+     return summarizer
+
+ def get_qa_model():
+     global qa_model
+     if qa_model is None:
+         qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
+     return qa_model
+
+ def get_image_captioner():
+     global image_captioner
+     if image_captioner is None:
+         image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
+     return image_captioner
+
+ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
+     """Validate and process uploaded file with special handling for each type"""
      if not file.filename:
          raise HTTPException(400, "No filename provided")

      file_ext = file.filename.split('.')[-1].lower()
      if file_ext not in SUPPORTED_FILE_TYPES:
+         raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")

      content = await file.read()
      if len(content) > MAX_FILE_SIZE:
          raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

+     # Special validation for PDFs
+     if file_ext == "pdf":
+         try:
+             with fitz.open(stream=content, filetype="pdf") as doc:
+                 if doc.is_encrypted:
+                     if not doc.authenticate(""):
+                         raise ValueError("Encrypted PDF - cannot extract text")
+                 if len(doc) > 50:
+                     raise ValueError("PDF too large (max 50 pages)")
+         except Exception as e:
+             logger.error(f"PDF validation failed: {str(e)}")
+             raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
+
+     await file.seek(0)  # Reset file pointer for processing
      return file_ext, content

+ def extract_text(content: bytes, file_ext: str) -> str:
+     """Extract text from various file formats with enhanced support"""
      try:
+         if file_ext == "docx":
+             doc = Document(io.BytesIO(content))
+             return "\n".join(para.text for para in doc.paragraphs if para.text.strip())

+         elif file_ext in {"xlsx", "xls"}:
+             df = pd.read_excel(io.BytesIO(content), sheet_name=None)
+             all_text = []
+             for sheet_name, sheet_data in df.items():
+                 sheet_text = []
+                 for column in sheet_data.columns:
+                     sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
+                 all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
+             return "\n\n".join(all_text)
+
+         elif file_ext == "pptx":
+             ppt = Presentation(io.BytesIO(content))
+             text = []
+             for slide in ppt.slides:
+                 for shape in slide.shapes:
+                     if hasattr(shape, "text") and shape.text.strip():
+                         text.append(shape.text)
+             return "\n".join(text)
+
+         elif file_ext == "pdf":
+             pdf = fitz.open(stream=content, filetype="pdf")
+             return "\n".join(page.get_text("text") for page in pdf)
+
+         elif file_ext in {"jpg", "jpeg", "png"}:
+             # First try OCR
+             try:
+                 image = Image.open(io.BytesIO(content))
+                 text = pytesseract.image_to_string(image, config='--psm 6')
+                 if text.strip():
+                     return text
+
+                 # If OCR fails, try image captioning
+                 captioner = get_image_captioner()
+                 result = captioner(image)
+                 return result[0]['generated_text']
+             except Exception as img_e:
+                 logger.error(f"Image processing failed: {str(img_e)}")
+                 raise ValueError("Could not extract text or caption from image")

      except Exception as e:
+         logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
+         raise HTTPException(422, f"Failed to extract text from {file_ext} file")

  @app.post("/summarize")
+ @limiter.limit("5/minute")
+ async def summarize_document(request: Request, file: UploadFile = File(...)):
      try:
+         file_ext, content = await process_uploaded_file(file)
+         text = extract_text(content, file_ext)

          if not text.strip():
              raise HTTPException(400, "No extractable text found")

+         # Clean and chunk text
+         text = re.sub(r'\s+', ' ', text).strip()
+         chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]

+         # Summarize each chunk
+         summarizer = get_summarizer()
+         summaries = []
+         for chunk in chunks:
+             summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
+             summaries.append(summary)

+         return {"summary": " ".join(summaries)}
+
+     except HTTPException:
+         raise
      except Exception as e:
+         logger.error(f"Summarization failed: {str(e)}")
+         raise HTTPException(500, "Document summarization failed")

  @app.post("/qa")
+ @limiter.limit("5/minute")
  async def question_answering(
+     request: Request,
      file: UploadFile = File(...),
      question: str = Form(...),
      language: str = Form("fr")
  ):
      try:
+         file_ext, content = await process_uploaded_file(file)
+         text = extract_text(content, file_ext)

          if not text.strip():
              raise HTTPException(400, "No extractable text found")
+
+         # Clean and truncate text
+         text = re.sub(r'\s+', ' ', text).strip()[:5000]
+
+         # Theme detection
          theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
          if any(kw in question.lower() for kw in theme_keywords):
              try:
+                 summarizer = get_summarizer()
+                 summary_output = summarizer(
+                     text,
+                     max_length=min(100, len(text)//4),
+                     min_length=30,
+                     do_sample=False,
+                     truncation=True
+                 )
+
+                 theme = summary_output[0].get("summary_text", text[:200] + "...")
                  return {
                      "question": question,
                      "answer": f"Le document traite principalement de : {theme}",

                      "language": language
                  }
              except Exception:
+                 theme = text[:200] + ("..." if len(text) > 200 else "")
                  return {
                      "question": question,
                      "answer": f"D'après le document : {theme}",

                      "language": language,
                      "warning": "theme_summary_fallback"
                  }
+
          # Standard QA
+         qa = get_qa_model()
+         result = qa(question=question, context=text[:3000])
+
          return {
              "question": question,
              "answer": result["answer"],
              "confidence": result["score"],
              "language": language
          }
+
+     except HTTPException:
+         raise
      except Exception as e:
+         logger.error(f"QA processing failed: {str(e)}")
+         raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
+
+ @app.exception_handler(RateLimitExceeded)
+ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
+     return JSONResponse(
+         status_code=429,
+         content={"detail": "Too many requests. Please try again later."}
+     )

  if __name__ == "__main__":
      uvicorn.run(app, host="0.0.0.0", port=7860)
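
For reference, a minimal client sketch for exercising the two endpoints defined in this version of main.py. It assumes the app is running locally on port 7860 (as in the uvicorn.run call above) and uses the requests library; the file name report.pdf is only a placeholder, and both endpoints are rate-limited to 5 requests per minute per client IP by the slowapi decorator.

import requests

BASE_URL = "http://localhost:7860"  # matches uvicorn.run(app, host="0.0.0.0", port=7860)

# Summarize an uploaded document
with open("report.pdf", "rb") as f:  # placeholder file; any supported type works
    r = requests.post(f"{BASE_URL}/summarize",
                      files={"file": ("report.pdf", f, "application/pdf")})
print(r.status_code, r.json())  # expected shape: {"summary": "..."}

# Ask a question about the same document
with open("report.pdf", "rb") as f:
    r = requests.post(
        f"{BASE_URL}/qa",
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"question": "Quel est le sujet principal ?", "language": "fr"},
    )
print(r.status_code, r.json())  # expected keys: question, answer, confidence, language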